diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 310fc2d4..a5e5fe79 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,12 +10,10 @@ jobs: runs-on: ubuntu-latest # container actions require GNU/Linux strategy: matrix: - coq_container: -# - coqorg/coq:8.12.2 -# - coqorg/coq:8.16.1-ocaml-4.13.1-flambda - - coqorg/coq:8.18.0-ocaml-4.13.1-flambda + rocq_container: + - rocq/rocq-prover:9.0.1-ocaml-4.14.2-flambda container: - image: ${{ matrix.coq_container }} + image: ${{ matrix.rocq_container }} options: --user root steps: - uses: actions/checkout@v4 @@ -26,31 +24,8 @@ jobs: - name: ls run: ls -la . - name: Install Opam dependencies - run: su coq -c 'eval $(opam env) && opam install --deps-only --with-test --with-doc -y -j 2 ./Formal_ML.opam' + run: su rocq -c 'eval $(opam env) && opam install --deps-only --with-test --with-doc -y -j 2 ./Formal_ML.opam' - name: Build using Make - run: su coq -c 'eval $(opam env) && make -kj 2' + run: su rocq -c 'eval $(opam env) && make -kj 2' - name: Build documentation - run: su coq -c 'eval $(opam env) && make -kj 2 doc' - -# - uses: coq-community/docker-coq-action@v1 -# with: -# opam_file: 'Formal_ML.opam' -# coq_version: ${{ matrix.coq_version }} -# ocaml_version: ${{ matrix.ocaml_version }} -# # export: 'OPAMWITHTEST OPAMWITHDOC' -# export: 'OPAMWITHDOC' -# after_script: | -# sudo cp -a $(opam config var Formal_ML:build)/documentation . -# env: -# OPAMWITHDOC: 'true' -# OPAMWITHTEST: 'true' - # - if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - # name: deploy documentation - # uses: JamesIves/github-pages-deploy-action@3.7.1 - # with: - # ACCESS_TOKEN: ${{ secrets.ACCESS_TOKEN }} - # REPOSITORY_NAME: FormalML/FormalML.github.io # the target repository - # TARGET_FOLDER: main/documentation # target directory - # BRANCH: main # The branch the action should deploy to. - # FOLDER: documentation # The folder the action should deploy. - # CLEAN: true # Automatically remove deleted files from the deploy branch + run: su rocq -c 'eval $(opam env) && make -kj 2 doc' diff --git a/.gitignore b/.gitignore index 7a410b78..42bb264a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,9 +6,9 @@ *.aux .coqdeps.d .#* -Makefile.coq -Makefile.coq.conf -.Makefile.coq.d +Makefile.rocq +Makefile.rocq.conf +.Makefile.rocq.d extracted ocaml/_build bin diff --git a/Dockerfile b/Dockerfile index 35a34409..a47e7e2f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,13 +1,13 @@ -ARG coq_image="coqorg/coq:8.12.2" -FROM ${coq_image} +ARG rocq_image="rocq/rocq-prover:9.0.1" +FROM ${rocq_image} MAINTAINER Avi Shinnar "shinnar@us.ibm.com" # needs to be a subdirectory to avoid causing problems with -# the /home/coq/.opam directory (and probably other stuff) -WORKDIR /home/coq +# the /home/rocq/.opam directory (and probably other stuff) +WORKDIR /home/rocq -COPY --chown=coq:coq Formal_ML.opam ./formal_ml/ +COPY --chown=rocq:rocq Formal_ML.opam ./formal_ml/ RUN ["/bin/bash", "--login", "-c", "set -x \ && if [ -n \"${COMPILER_EDGE}\" ]; then opam switch ${COMPILER_EDGE} && eval $(opam env); fi \ @@ -16,13 +16,8 @@ RUN ["/bin/bash", "--login", "-c", "set -x \ && opam clean -a -c -s --logs"] -COPY --chown=coq:coq breast-cancer-wisconsin.data breast-cancer-wisconsin.names ./formal_ml/ -COPY --chown=coq:coq _CoqProject Makefile Makefile.coq_modules ./formal_ml/ -COPY --chown=coq:coq coq ./formal_ml/coq -COPY --chown=coq:coq ocaml ./formal_ml/ocaml +COPY --chown=rocq:rocq _RocqProject Makefile Makefile.rocq_modules ./formal_ml/ +COPY --chown=rocq:rocq rocq ./formal_ml/rocq RUN ["/bin/bash", "--login", "-c", "set -x && cd formal_ml && \ make && make doc"] - -# CMD ["/bin/bash", "--login", "-c", "set -x && cd formal_ml && \ -# make test"] \ No newline at end of file diff --git a/Formal_ML.opam b/Formal_ML.opam index 8d26190f..0c08b859 100644 --- a/Formal_ML.opam +++ b/Formal_ML.opam @@ -9,25 +9,15 @@ homepage: "https://github.com/ibm/formalml" bug-reports: "https://github.com/ibm/formalml/issues" depends: [ "ocaml" {>= "4.07.0"} - "coq" {>= "8.12.1"} - "coq-mathcomp-ssreflect" - "coq-mathcomp-algebra" - "coq-mathcomp-algebra-tactics" - "coq-mathcomp-real-closed" - "coq-mathcomp-analysis" {< "1.0.0"} - "coq-coquelicot" {= "3.3.1" } - "coq-flocq" {>= "4.0.0" } - "coq-interval" {>= "4.8.0"} + "rocq-core" {>= "9.0.0"} + "rocq-stdlib" + "rocq-mathcomp-ssreflect" + "coq-coquelicot" "coq-ext-lib" {<= "1.0.0"} - "ocamlbuild" - "base64" - "menhir" - "csv" "coq-coq2html" {with-doc} ] build: [[make] [make "doc"] {with-doc} - [make "test"] {with-test} ] install: [make] dev-repo: "git+https://github.com/IBM/FormalML.git" diff --git a/Makefile b/Makefile index 0f60abbf..a11ded7d 100644 --- a/Makefile +++ b/Makefile @@ -1,38 +1,29 @@ -# Contains the list of all the Coq modules -include Makefile.coq_modules +# Contains the list of all the Rocq modules +include Makefile.rocq_modules -COQ_FILES = $(addprefix coq/,$(MODULES:%=%.v)) +ROCQ_FILES = $(addprefix rocq/,$(MODULES:%=%.v)) -all: coq # ocaml +all: rocq -coq: Makefile.coq - @$(MAKE) -f Makefile.coq +rocq: Makefile.rocq + @$(MAKE) -f Makefile.rocq -Makefile.coq: Makefile Makefile.coq_modules $(COQ_FILES) - @coq_makefile -f _CoqProject $(COQ_FILES) -o Makefile.coq +Makefile.rocq: Makefile Makefile.rocq_modules $(ROCQ_FILES) + @rocq makefile -f _RocqProject $(ROCQ_FILES) -o Makefile.rocq -ocaml: coq - @$(MAKE) -C ocaml native +clean-rocq: + - @$(MAKE) -f Makefile.rocq clean -clean-coq: - - @$(MAKE) -f Makefile.coq clean -clean-ocaml: - @$(MAKE) -C ocaml clean - - -COQ_FILES_FOR_DOC = $(MODULES:%=%.v) +ROCQ_FILES_FOR_DOC = $(MODULES:%=%.v) GLOB_FILES_FOR_DOC = $(MODULES:%=%.glob) -doc: coq +doc: rocq mkdir -p documentation/html rm -f documentation/html/*.html - cd coq && coq2html -d ../documentation/html -base FormalML -external http://coquelicot.saclay.inria.fr/html/ Coquelicot $(COQ_FILES_FOR_DOC) $(GLOB_FILES_FOR_DOC) - -test: coq ocaml - ./bin/nnopt + cd rocq && coq2html -d ../documentation/html -base FormalML -external http://coquelicot.saclay.inria.fr/html/ Coquelicot $(ROCQ_FILES_FOR_DOC) $(GLOB_FILES_FOR_DOC) -clean: clean-coq clean-ocaml +clean: clean-rocq rm -rf documentation/html -.PHONY: all ocaml clean clean-coq coq test doc +.PHONY: all clean clean-rocq rocq doc diff --git a/Makefile.coq_modules b/Makefile.rocq_modules similarity index 84% rename from Makefile.coq_modules rename to Makefile.rocq_modules index 981360eb..d5891582 100644 --- a/Makefile.coq_modules +++ b/Makefile.rocq_modules @@ -30,7 +30,6 @@ UTILS = BasicUtils \ ClassicUtils \ CoquelicotAdd \ ELim_Seq \ - ExtrFloatishIEEE \ improper_integrals \ Isomorphism \ ListAdd \ @@ -46,24 +45,11 @@ UTILS = BasicUtils \ StreamAdd \ StreamLimits \ Sums \ - nvector \ Vector \ PushNeg \ DVector \ Utils \ - Floatish/FloatishDef \ - Floatish/FloatishOps \ - Floatish/FloatishRealOps \ - Floatish/FloatishInterval \ - Floatish/FloatishIEEE \ - Floatish/FloatishReal \ - Floatish - -NEURAL_NETWORKS = AxiomaticNormedRealVectorSpace \ - DefinedFunctions \ - derivlemmas \ - Gen_NN - + derivlemmas CERTRL = pmf_monad \ qvalues \ @@ -122,20 +108,10 @@ QLEARN = \ Tsitsiklis \ jaakkola_vector -FHE = \ - nth_root \ - encode \ - encrypt \ - zp_prim_root \ - arith - MODULES = $(addprefix lib_utils/,$(QCERT_LIB_UTILS)) \ $(addprefix CertRL/LM/,$(ELFIC_UTILS)) \ $(addprefix utils/,$(UTILS)) \ - $(addprefix NeuralNetworks/,$(NEURAL_NETWORKS)) \ $(addprefix ProbTheory/,$(PROB_THEORY)) \ $(addprefix CertRL/,$(CERTRL)) \ $(addprefix QLearn/,$(QLEARN)) \ - $(addprefix FHE/,$(FHE)) \ - API - + diff --git a/README.md b/README.md index ccbb75e9..04a0e7e7 100644 --- a/README.md +++ b/README.md @@ -8,15 +8,15 @@ ## Getting Started - To compile the Coq code in this repository, + To compile the [Rocq](https://rocq-prover.org/) (previously known as Coq) code in this repository, [Install Rocq](https://rocq-prover.org/install). For example: - first install opam [opam (ocaml package manager)](https://opam.ocaml.org/). - - Add support for coq ocaml repositories: `opam repo add coq-released --set-default https://coq.inria.fr/opam/released`. - - If you want to create a local environment (switch), you can run `opam switch create nnsopt 4.07.0`. - - Next, run `opam install . --deps-only`. This should install all the dependencies needed, including Coq. + - Add support for rocq ocaml repositories: `opam repo add rocq-released https://rocq-prover.org/opam/released` + - If you want to create a local environment (switch), you can run `opam switch create formalml 4.14.2`. + - Next, run `opam install . --deps-only`. This should install all the dependencies needed, including Rocq. - Once the prerequisites are installed, run `make` to compile it. - Alternatively, the included Docker file can be built using Docker to compile the coq code in a suitable environment. - `docker build --build-arg=coq_image="coqorg/coq:8.8.2" --pull -t nn_sopt .` + Alternatively, the included Docker file can be built using Docker to compile the rocq code in a suitable environment. + `docker build --pull -t formalml .` ## License This repository is distributed under the terms of the Apache 2.0 License, see LICENSE.txt. diff --git a/_CoqProject b/_CoqProject deleted file mode 100644 index 101bc735..00000000 --- a/_CoqProject +++ /dev/null @@ -1 +0,0 @@ --R coq FormalML -arg -set -arg "Warnings=+default,-ambiguous-path,-coercions,-hiding-delimiting-key,-overwriting-delimiting-key,-redundant-canonical-projection,-typechecker,-ssr-search-moved,-deprecated,-notation-overridden" \ No newline at end of file diff --git a/_RocqProject b/_RocqProject new file mode 100644 index 00000000..cf77dfa1 --- /dev/null +++ b/_RocqProject @@ -0,0 +1 @@ +-R rocq FormalML -arg -set -arg "Warnings=+default,-ambiguous-path,-coercions,-hiding-delimiting-key,-overwriting-delimiting-key,-redundant-canonical-projection,-typechecker,-ssr-search-moved,-deprecated,-notation-overridden" \ No newline at end of file diff --git a/breast-cancer-wisconsin.data b/breast-cancer-wisconsin.data deleted file mode 100644 index 9e329d00..00000000 --- a/breast-cancer-wisconsin.data +++ /dev/null @@ -1,683 +0,0 @@ -1000025,5,1,1,1,2,1,3,1,1,2 -1002945,5,4,4,5,7,10,3,2,1,2 -1015425,3,1,1,1,2,2,3,1,1,2 -1016277,6,8,8,1,3,4,3,7,1,2 -1017023,4,1,1,3,2,1,3,1,1,2 -1017122,8,10,10,8,7,10,9,7,1,4 -1018099,1,1,1,1,2,10,3,1,1,2 -1018561,2,1,2,1,2,1,3,1,1,2 -1033078,2,1,1,1,2,1,1,1,5,2 -1033078,4,2,1,1,2,1,2,1,1,2 -1035283,1,1,1,1,1,1,3,1,1,2 -1036172,2,1,1,1,2,1,2,1,1,2 -1041801,5,3,3,3,2,3,4,4,1,4 -1043999,1,1,1,1,2,3,3,1,1,2 -1044572,8,7,5,10,7,9,5,5,4,4 -1047630,7,4,6,4,6,1,4,3,1,4 -1048672,4,1,1,1,2,1,2,1,1,2 -1049815,4,1,1,1,2,1,3,1,1,2 -1050670,10,7,7,6,4,10,4,1,2,4 -1050718,6,1,1,1,2,1,3,1,1,2 -1054590,7,3,2,10,5,10,5,4,4,4 -1054593,10,5,5,3,6,7,7,10,1,4 -1056784,3,1,1,1,2,1,2,1,1,2 -1059552,1,1,1,1,2,1,3,1,1,2 -1065726,5,2,3,4,2,7,3,6,1,4 -1066373,3,2,1,1,1,1,2,1,1,2 -1066979,5,1,1,1,2,1,2,1,1,2 -1067444,2,1,1,1,2,1,2,1,1,2 -1070935,1,1,3,1,2,1,1,1,1,2 -1070935,3,1,1,1,1,1,2,1,1,2 -1071760,2,1,1,1,2,1,3,1,1,2 -1072179,10,7,7,3,8,5,7,4,3,4 -1074610,2,1,1,2,2,1,3,1,1,2 -1075123,3,1,2,1,2,1,2,1,1,2 -1079304,2,1,1,1,2,1,2,1,1,2 -1080185,10,10,10,8,6,1,8,9,1,4 -1081791,6,2,1,1,1,1,7,1,1,2 -1084584,5,4,4,9,2,10,5,6,1,4 -1091262,2,5,3,3,6,7,7,5,1,4 -1099510,10,4,3,1,3,3,6,5,2,4 -1100524,6,10,10,2,8,10,7,3,3,4 -1102573,5,6,5,6,10,1,3,1,1,4 -1103608,10,10,10,4,8,1,8,10,1,4 -1103722,1,1,1,1,2,1,2,1,2,2 -1105257,3,7,7,4,4,9,4,8,1,4 -1105524,1,1,1,1,2,1,2,1,1,2 -1106095,4,1,1,3,2,1,3,1,1,2 -1106829,7,8,7,2,4,8,3,8,2,4 -1108370,9,5,8,1,2,3,2,1,5,4 -1108449,5,3,3,4,2,4,3,4,1,4 -1110102,10,3,6,2,3,5,4,10,2,4 -1110503,5,5,5,8,10,8,7,3,7,4 -1110524,10,5,5,6,8,8,7,1,1,4 -1111249,10,6,6,3,4,5,3,6,1,4 -1112209,8,10,10,1,3,6,3,9,1,4 -1113038,8,2,4,1,5,1,5,4,4,4 -1113483,5,2,3,1,6,10,5,1,1,4 -1113906,9,5,5,2,2,2,5,1,1,4 -1115282,5,3,5,5,3,3,4,10,1,4 -1115293,1,1,1,1,2,2,2,1,1,2 -1116116,9,10,10,1,10,8,3,3,1,4 -1116132,6,3,4,1,5,2,3,9,1,4 -1116192,1,1,1,1,2,1,2,1,1,2 -1116998,10,4,2,1,3,2,4,3,10,4 -1117152,4,1,1,1,2,1,3,1,1,2 -1118039,5,3,4,1,8,10,4,9,1,4 -1120559,8,3,8,3,4,9,8,9,8,4 -1121732,1,1,1,1,2,1,3,2,1,2 -1121919,5,1,3,1,2,1,2,1,1,2 -1123061,6,10,2,8,10,2,7,8,10,4 -1124651,1,3,3,2,2,1,7,2,1,2 -1125035,9,4,5,10,6,10,4,8,1,4 -1126417,10,6,4,1,3,4,3,2,3,4 -1131294,1,1,2,1,2,2,4,2,1,2 -1132347,1,1,4,1,2,1,2,1,1,2 -1133041,5,3,1,2,2,1,2,1,1,2 -1133136,3,1,1,1,2,3,3,1,1,2 -1136142,2,1,1,1,3,1,2,1,1,2 -1137156,2,2,2,1,1,1,7,1,1,2 -1143978,4,1,1,2,2,1,2,1,1,2 -1143978,5,2,1,1,2,1,3,1,1,2 -1147044,3,1,1,1,2,2,7,1,1,2 -1147699,3,5,7,8,8,9,7,10,7,4 -1147748,5,10,6,1,10,4,4,10,10,4 -1148278,3,3,6,4,5,8,4,4,1,4 -1148873,3,6,6,6,5,10,6,8,3,4 -1152331,4,1,1,1,2,1,3,1,1,2 -1155546,2,1,1,2,3,1,2,1,1,2 -1156272,1,1,1,1,2,1,3,1,1,2 -1156948,3,1,1,2,2,1,1,1,1,2 -1157734,4,1,1,1,2,1,3,1,1,2 -1158247,1,1,1,1,2,1,2,1,1,2 -1160476,2,1,1,1,2,1,3,1,1,2 -1164066,1,1,1,1,2,1,3,1,1,2 -1165297,2,1,1,2,2,1,1,1,1,2 -1165790,5,1,1,1,2,1,3,1,1,2 -1165926,9,6,9,2,10,6,2,9,10,4 -1166630,7,5,6,10,5,10,7,9,4,4 -1166654,10,3,5,1,10,5,3,10,2,4 -1167439,2,3,4,4,2,5,2,5,1,4 -1167471,4,1,2,1,2,1,3,1,1,2 -1168359,8,2,3,1,6,3,7,1,1,4 -1168736,10,10,10,10,10,1,8,8,8,4 -1169049,7,3,4,4,3,3,3,2,7,4 -1170419,10,10,10,8,2,10,4,1,1,4 -1170420,1,6,8,10,8,10,5,7,1,4 -1171710,1,1,1,1,2,1,2,3,1,2 -1171710,6,5,4,4,3,9,7,8,3,4 -1171795,1,3,1,2,2,2,5,3,2,2 -1171845,8,6,4,3,5,9,3,1,1,4 -1172152,10,3,3,10,2,10,7,3,3,4 -1173216,10,10,10,3,10,8,8,1,1,4 -1173235,3,3,2,1,2,3,3,1,1,2 -1173347,1,1,1,1,2,5,1,1,1,2 -1173347,8,3,3,1,2,2,3,2,1,2 -1173509,4,5,5,10,4,10,7,5,8,4 -1173514,1,1,1,1,4,3,1,1,1,2 -1173681,3,2,1,1,2,2,3,1,1,2 -1174057,1,1,2,2,2,1,3,1,1,2 -1174057,4,2,1,1,2,2,3,1,1,2 -1174131,10,10,10,2,10,10,5,3,3,4 -1174428,5,3,5,1,8,10,5,3,1,4 -1175937,5,4,6,7,9,7,8,10,1,4 -1176406,1,1,1,1,2,1,2,1,1,2 -1176881,7,5,3,7,4,10,7,5,5,4 -1177027,3,1,1,1,2,1,3,1,1,2 -1177399,8,3,5,4,5,10,1,6,2,4 -1177512,1,1,1,1,10,1,1,1,1,2 -1178580,5,1,3,1,2,1,2,1,1,2 -1179818,2,1,1,1,2,1,3,1,1,2 -1180194,5,10,8,10,8,10,3,6,3,4 -1180523,3,1,1,1,2,1,2,2,1,2 -1180831,3,1,1,1,3,1,2,1,1,2 -1181356,5,1,1,1,2,2,3,3,1,2 -1182404,4,1,1,1,2,1,2,1,1,2 -1182410,3,1,1,1,2,1,1,1,1,2 -1183240,4,1,2,1,2,1,2,1,1,2 -1183516,3,1,1,1,2,1,1,1,1,2 -1183911,2,1,1,1,2,1,1,1,1,2 -1183983,9,5,5,4,4,5,4,3,3,4 -1184184,1,1,1,1,2,5,1,1,1,2 -1184241,2,1,1,1,2,1,2,1,1,2 -1185609,3,4,5,2,6,8,4,1,1,4 -1185610,1,1,1,1,3,2,2,1,1,2 -1187457,3,1,1,3,8,1,5,8,1,2 -1187805,8,8,7,4,10,10,7,8,7,4 -1188472,1,1,1,1,1,1,3,1,1,2 -1189266,7,2,4,1,6,10,5,4,3,4 -1189286,10,10,8,6,4,5,8,10,1,4 -1190394,4,1,1,1,2,3,1,1,1,2 -1190485,1,1,1,1,2,1,1,1,1,2 -1192325,5,5,5,6,3,10,3,1,1,4 -1193091,1,2,2,1,2,1,2,1,1,2 -1193210,2,1,1,1,2,1,3,1,1,2 -1196295,9,9,10,3,6,10,7,10,6,4 -1196915,10,7,7,4,5,10,5,7,2,4 -1197080,4,1,1,1,2,1,3,2,1,2 -1197270,3,1,1,1,2,1,3,1,1,2 -1197440,1,1,1,2,1,3,1,1,7,2 -1197979,4,1,1,1,2,2,3,2,1,2 -1197993,5,6,7,8,8,10,3,10,3,4 -1198128,10,8,10,10,6,1,3,1,10,4 -1198641,3,1,1,1,2,1,3,1,1,2 -1199219,1,1,1,2,1,1,1,1,1,2 -1199731,3,1,1,1,2,1,1,1,1,2 -1199983,1,1,1,1,2,1,3,1,1,2 -1200772,1,1,1,1,2,1,2,1,1,2 -1200847,6,10,10,10,8,10,10,10,7,4 -1200892,8,6,5,4,3,10,6,1,1,4 -1200952,5,8,7,7,10,10,5,7,1,4 -1201834,2,1,1,1,2,1,3,1,1,2 -1201936,5,10,10,3,8,1,5,10,3,4 -1202125,4,1,1,1,2,1,3,1,1,2 -1202812,5,3,3,3,6,10,3,1,1,4 -1203096,1,1,1,1,1,1,3,1,1,2 -1204242,1,1,1,1,2,1,1,1,1,2 -1204898,6,1,1,1,2,1,3,1,1,2 -1205138,5,8,8,8,5,10,7,8,1,4 -1205579,8,7,6,4,4,10,5,1,1,4 -1206089,2,1,1,1,1,1,3,1,1,2 -1206695,1,5,8,6,5,8,7,10,1,4 -1206841,10,5,6,10,6,10,7,7,10,4 -1207986,5,8,4,10,5,8,9,10,1,4 -1208301,1,2,3,1,2,1,3,1,1,2 -1210963,10,10,10,8,6,8,7,10,1,4 -1211202,7,5,10,10,10,10,4,10,3,4 -1212232,5,1,1,1,2,1,2,1,1,2 -1212251,1,1,1,1,2,1,3,1,1,2 -1212422,3,1,1,1,2,1,3,1,1,2 -1212422,4,1,1,1,2,1,3,1,1,2 -1213375,8,4,4,5,4,7,7,8,2,2 -1213383,5,1,1,4,2,1,3,1,1,2 -1214092,1,1,1,1,2,1,1,1,1,2 -1214556,3,1,1,1,2,1,2,1,1,2 -1214966,9,7,7,5,5,10,7,8,3,4 -1216694,10,8,8,4,10,10,8,1,1,4 -1216947,1,1,1,1,2,1,3,1,1,2 -1217051,5,1,1,1,2,1,3,1,1,2 -1217264,1,1,1,1,2,1,3,1,1,2 -1218105,5,10,10,9,6,10,7,10,5,4 -1218741,10,10,9,3,7,5,3,5,1,4 -1218860,1,1,1,1,1,1,3,1,1,2 -1218860,1,1,1,1,1,1,3,1,1,2 -1219406,5,1,1,1,1,1,3,1,1,2 -1219525,8,10,10,10,5,10,8,10,6,4 -1219859,8,10,8,8,4,8,7,7,1,4 -1220330,1,1,1,1,2,1,3,1,1,2 -1221863,10,10,10,10,7,10,7,10,4,4 -1222047,10,10,10,10,3,10,10,6,1,4 -1222936,8,7,8,7,5,5,5,10,2,4 -1223282,1,1,1,1,2,1,2,1,1,2 -1223426,1,1,1,1,2,1,3,1,1,2 -1223793,6,10,7,7,6,4,8,10,2,4 -1223967,6,1,3,1,2,1,3,1,1,2 -1224329,1,1,1,2,2,1,3,1,1,2 -1225799,10,6,4,3,10,10,9,10,1,4 -1226012,4,1,1,3,1,5,2,1,1,4 -1226612,7,5,6,3,3,8,7,4,1,4 -1227210,10,5,5,6,3,10,7,9,2,4 -1227244,1,1,1,1,2,1,2,1,1,2 -1227481,10,5,7,4,4,10,8,9,1,4 -1228152,8,9,9,5,3,5,7,7,1,4 -1228311,1,1,1,1,1,1,3,1,1,2 -1230175,10,10,10,3,10,10,9,10,1,4 -1230688,7,4,7,4,3,7,7,6,1,4 -1231387,6,8,7,5,6,8,8,9,2,4 -1231706,8,4,6,3,3,1,4,3,1,2 -1232225,10,4,5,5,5,10,4,1,1,4 -1236043,3,3,2,1,3,1,3,6,1,2 -1241559,10,8,8,2,8,10,4,8,10,4 -1241679,9,8,8,5,6,2,4,10,4,4 -1242364,8,10,10,8,6,9,3,10,10,4 -1243256,10,4,3,2,3,10,5,3,2,4 -1270479,5,1,3,3,2,2,2,3,1,2 -1276091,3,1,1,3,1,1,3,1,1,2 -1277018,2,1,1,1,2,1,3,1,1,2 -128059,1,1,1,1,2,5,5,1,1,2 -1285531,1,1,1,1,2,1,3,1,1,2 -1287775,5,1,1,2,2,2,3,1,1,2 -144888,8,10,10,8,5,10,7,8,1,4 -145447,8,4,4,1,2,9,3,3,1,4 -167528,4,1,1,1,2,1,3,6,1,2 -183913,1,2,2,1,2,1,1,1,1,2 -191250,10,4,4,10,2,10,5,3,3,4 -1017023,6,3,3,5,3,10,3,5,3,2 -1100524,6,10,10,2,8,10,7,3,3,4 -1116116,9,10,10,1,10,8,3,3,1,4 -1168736,5,6,6,2,4,10,3,6,1,4 -1182404,3,1,1,1,2,1,1,1,1,2 -1182404,3,1,1,1,2,1,2,1,1,2 -1198641,3,1,1,1,2,1,3,1,1,2 -242970,5,7,7,1,5,8,3,4,1,2 -255644,10,5,8,10,3,10,5,1,3,4 -263538,5,10,10,6,10,10,10,6,5,4 -274137,8,8,9,4,5,10,7,8,1,4 -303213,10,4,4,10,6,10,5,5,1,4 -314428,7,9,4,10,10,3,5,3,3,4 -1182404,5,1,4,1,2,1,3,2,1,2 -1198641,10,10,6,3,3,10,4,3,2,4 -320675,3,3,5,2,3,10,7,1,1,4 -324427,10,8,8,2,3,4,8,7,8,4 -385103,1,1,1,1,2,1,3,1,1,2 -390840,8,4,7,1,3,10,3,9,2,4 -411453,5,1,1,1,2,1,3,1,1,2 -320675,3,3,5,2,3,10,7,1,1,4 -428903,7,2,4,1,3,4,3,3,1,4 -431495,3,1,1,1,2,1,3,2,1,2 -434518,3,1,1,1,2,1,2,1,1,2 -452264,1,1,1,1,2,1,2,1,1,2 -456282,1,1,1,1,2,1,3,1,1,2 -476903,10,5,7,3,3,7,3,3,8,4 -486283,3,1,1,1,2,1,3,1,1,2 -486662,2,1,1,2,2,1,3,1,1,2 -488173,1,4,3,10,4,10,5,6,1,4 -492268,10,4,6,1,2,10,5,3,1,4 -508234,7,4,5,10,2,10,3,8,2,4 -527363,8,10,10,10,8,10,10,7,3,4 -529329,10,10,10,10,10,10,4,10,10,4 -535331,3,1,1,1,3,1,2,1,1,2 -543558,6,1,3,1,4,5,5,10,1,4 -555977,5,6,6,8,6,10,4,10,4,4 -560680,1,1,1,1,2,1,1,1,1,2 -561477,1,1,1,1,2,1,3,1,1,2 -601265,10,4,4,6,2,10,2,3,1,4 -606722,5,5,7,8,6,10,7,4,1,4 -616240,5,3,4,3,4,5,4,7,1,2 -625201,8,2,1,1,5,1,1,1,1,2 -63375,9,1,2,6,4,10,7,7,2,4 -635844,8,4,10,5,4,4,7,10,1,4 -636130,1,1,1,1,2,1,3,1,1,2 -640744,10,10,10,7,9,10,7,10,10,4 -646904,1,1,1,1,2,1,3,1,1,2 -653777,8,3,4,9,3,10,3,3,1,4 -659642,10,8,4,4,4,10,3,10,4,4 -666090,1,1,1,1,2,1,3,1,1,2 -666942,1,1,1,1,2,1,3,1,1,2 -667204,7,8,7,6,4,3,8,8,4,4 -673637,3,1,1,1,2,5,5,1,1,2 -684955,2,1,1,1,3,1,2,1,1,2 -688033,1,1,1,1,2,1,1,1,1,2 -691628,8,6,4,10,10,1,3,5,1,4 -693702,1,1,1,1,2,1,1,1,1,2 -704097,1,1,1,1,1,1,2,1,1,2 -706426,5,5,5,2,5,10,4,3,1,4 -709287,6,8,7,8,6,8,8,9,1,4 -718641,1,1,1,1,5,1,3,1,1,2 -721482,4,4,4,4,6,5,7,3,1,2 -730881,7,6,3,2,5,10,7,4,6,4 -733639,3,1,1,1,2,1,3,1,1,2 -733823,5,4,6,10,2,10,4,1,1,4 -740492,1,1,1,1,2,1,3,1,1,2 -743348,3,2,2,1,2,1,2,3,1,2 -752904,10,1,1,1,2,10,5,4,1,4 -756136,1,1,1,1,2,1,2,1,1,2 -760001,8,10,3,2,6,4,3,10,1,4 -760239,10,4,6,4,5,10,7,1,1,4 -76389,10,4,7,2,2,8,6,1,1,4 -764974,5,1,1,1,2,1,3,1,2,2 -770066,5,2,2,2,2,1,2,2,1,2 -785208,5,4,6,6,4,10,4,3,1,4 -785615,8,6,7,3,3,10,3,4,2,4 -792744,1,1,1,1,2,1,1,1,1,2 -797327,6,5,5,8,4,10,3,4,1,4 -798429,1,1,1,1,2,1,3,1,1,2 -704097,1,1,1,1,1,1,2,1,1,2 -806423,8,5,5,5,2,10,4,3,1,4 -809912,10,3,3,1,2,10,7,6,1,4 -810104,1,1,1,1,2,1,3,1,1,2 -814265,2,1,1,1,2,1,1,1,1,2 -814911,1,1,1,1,2,1,1,1,1,2 -822829,7,6,4,8,10,10,9,5,3,4 -826923,1,1,1,1,2,1,1,1,1,2 -830690,5,2,2,2,3,1,1,3,1,2 -831268,1,1,1,1,1,1,1,3,1,2 -832226,3,4,4,10,5,1,3,3,1,4 -832567,4,2,3,5,3,8,7,6,1,4 -836433,5,1,1,3,2,1,1,1,1,2 -837082,2,1,1,1,2,1,3,1,1,2 -846832,3,4,5,3,7,3,4,6,1,2 -850831,2,7,10,10,7,10,4,9,4,4 -855524,1,1,1,1,2,1,2,1,1,2 -857774,4,1,1,1,3,1,2,2,1,2 -859164,5,3,3,1,3,3,3,3,3,4 -859350,8,10,10,7,10,10,7,3,8,4 -866325,8,10,5,3,8,4,4,10,3,4 -873549,10,3,5,4,3,7,3,5,3,4 -877291,6,10,10,10,10,10,8,10,10,4 -877943,3,10,3,10,6,10,5,1,4,4 -888169,3,2,2,1,4,3,2,1,1,2 -888523,4,4,4,2,2,3,2,1,1,2 -896404,2,1,1,1,2,1,3,1,1,2 -897172,2,1,1,1,2,1,2,1,1,2 -95719,6,10,10,10,8,10,7,10,7,4 -160296,5,8,8,10,5,10,8,10,3,4 -342245,1,1,3,1,2,1,1,1,1,2 -428598,1,1,3,1,1,1,2,1,1,2 -492561,4,3,2,1,3,1,2,1,1,2 -493452,1,1,3,1,2,1,1,1,1,2 -493452,4,1,2,1,2,1,2,1,1,2 -521441,5,1,1,2,2,1,2,1,1,2 -560680,3,1,2,1,2,1,2,1,1,2 -636437,1,1,1,1,2,1,1,1,1,2 -640712,1,1,1,1,2,1,2,1,1,2 -654244,1,1,1,1,1,1,2,1,1,2 -657753,3,1,1,4,3,1,2,2,1,2 -685977,5,3,4,1,4,1,3,1,1,2 -805448,1,1,1,1,2,1,1,1,1,2 -846423,10,6,3,6,4,10,7,8,4,4 -1002504,3,2,2,2,2,1,3,2,1,2 -1022257,2,1,1,1,2,1,1,1,1,2 -1026122,2,1,1,1,2,1,1,1,1,2 -1071084,3,3,2,2,3,1,1,2,3,2 -1080233,7,6,6,3,2,10,7,1,1,4 -1114570,5,3,3,2,3,1,3,1,1,2 -1114570,2,1,1,1,2,1,2,2,1,2 -1116715,5,1,1,1,3,2,2,2,1,2 -1131411,1,1,1,2,2,1,2,1,1,2 -1151734,10,8,7,4,3,10,7,9,1,4 -1156017,3,1,1,1,2,1,2,1,1,2 -1158247,1,1,1,1,1,1,1,1,1,2 -1158405,1,2,3,1,2,1,2,1,1,2 -1168278,3,1,1,1,2,1,2,1,1,2 -1176187,3,1,1,1,2,1,3,1,1,2 -1196263,4,1,1,1,2,1,1,1,1,2 -1196475,3,2,1,1,2,1,2,2,1,2 -1206314,1,2,3,1,2,1,1,1,1,2 -1211265,3,10,8,7,6,9,9,3,8,4 -1213784,3,1,1,1,2,1,1,1,1,2 -1223003,5,3,3,1,2,1,2,1,1,2 -1223306,3,1,1,1,2,4,1,1,1,2 -1223543,1,2,1,3,2,1,1,2,1,2 -1229929,1,1,1,1,2,1,2,1,1,2 -1231853,4,2,2,1,2,1,2,1,1,2 -1234554,1,1,1,1,2,1,2,1,1,2 -1236837,2,3,2,2,2,2,3,1,1,2 -1237674,3,1,2,1,2,1,2,1,1,2 -1238021,1,1,1,1,2,1,2,1,1,2 -1238633,10,10,10,6,8,4,8,5,1,4 -1238915,5,1,2,1,2,1,3,1,1,2 -1238948,8,5,6,2,3,10,6,6,1,4 -1239232,3,3,2,6,3,3,3,5,1,2 -1239347,8,7,8,5,10,10,7,2,1,4 -1239967,1,1,1,1,2,1,2,1,1,2 -1240337,5,2,2,2,2,2,3,2,2,2 -1253505,2,3,1,1,5,1,1,1,1,2 -1255384,3,2,2,3,2,3,3,1,1,2 -1257200,10,10,10,7,10,10,8,2,1,4 -1257648,4,3,3,1,2,1,3,3,1,2 -1257815,5,1,3,1,2,1,2,1,1,2 -1257938,3,1,1,1,2,1,1,1,1,2 -1258549,9,10,10,10,10,10,10,10,1,4 -1258556,5,3,6,1,2,1,1,1,1,2 -1266154,8,7,8,2,4,2,5,10,1,4 -1272039,1,1,1,1,2,1,2,1,1,2 -1276091,2,1,1,1,2,1,2,1,1,2 -1276091,1,3,1,1,2,1,2,2,1,2 -1276091,5,1,1,3,4,1,3,2,1,2 -1277629,5,1,1,1,2,1,2,2,1,2 -1293439,3,2,2,3,2,1,1,1,1,2 -1293439,6,9,7,5,5,8,4,2,1,2 -1294562,10,8,10,1,3,10,5,1,1,4 -1295186,10,10,10,1,6,1,2,8,1,4 -527337,4,1,1,1,2,1,1,1,1,2 -558538,4,1,3,3,2,1,1,1,1,2 -566509,5,1,1,1,2,1,1,1,1,2 -608157,10,4,3,10,4,10,10,1,1,4 -677910,5,2,2,4,2,4,1,1,1,2 -734111,1,1,1,3,2,3,1,1,1,2 -734111,1,1,1,1,2,2,1,1,1,2 -780555,5,1,1,6,3,1,2,1,1,2 -827627,2,1,1,1,2,1,1,1,1,2 -1049837,1,1,1,1,2,1,1,1,1,2 -1058849,5,1,1,1,2,1,1,1,1,2 -1182404,1,1,1,1,1,1,1,1,1,2 -1193544,5,7,9,8,6,10,8,10,1,4 -1201870,4,1,1,3,1,1,2,1,1,2 -1202253,5,1,1,1,2,1,1,1,1,2 -1227081,3,1,1,3,2,1,1,1,1,2 -1230994,4,5,5,8,6,10,10,7,1,4 -1238410,2,3,1,1,3,1,1,1,1,2 -1246562,10,2,2,1,2,6,1,1,2,4 -1257470,10,6,5,8,5,10,8,6,1,4 -1259008,8,8,9,6,6,3,10,10,1,4 -1266124,5,1,2,1,2,1,1,1,1,2 -1267898,5,1,3,1,2,1,1,1,1,2 -1268313,5,1,1,3,2,1,1,1,1,2 -1268804,3,1,1,1,2,5,1,1,1,2 -1276091,6,1,1,3,2,1,1,1,1,2 -1280258,4,1,1,1,2,1,1,2,1,2 -1293966,4,1,1,1,2,1,1,1,1,2 -1296572,10,9,8,7,6,4,7,10,3,4 -1298416,10,6,6,2,4,10,9,7,1,4 -1299596,6,6,6,5,4,10,7,6,2,4 -1105524,4,1,1,1,2,1,1,1,1,2 -1181685,1,1,2,1,2,1,2,1,1,2 -1211594,3,1,1,1,1,1,2,1,1,2 -1238777,6,1,1,3,2,1,1,1,1,2 -1257608,6,1,1,1,1,1,1,1,1,2 -1269574,4,1,1,1,2,1,1,1,1,2 -1277145,5,1,1,1,2,1,1,1,1,2 -1287282,3,1,1,1,2,1,1,1,1,2 -1296025,4,1,2,1,2,1,1,1,1,2 -1296263,4,1,1,1,2,1,1,1,1,2 -1296593,5,2,1,1,2,1,1,1,1,2 -1299161,4,8,7,10,4,10,7,5,1,4 -1301945,5,1,1,1,1,1,1,1,1,2 -1302428,5,3,2,4,2,1,1,1,1,2 -1318169,9,10,10,10,10,5,10,10,10,4 -474162,8,7,8,5,5,10,9,10,1,4 -787451,5,1,2,1,2,1,1,1,1,2 -1002025,1,1,1,3,1,3,1,1,1,2 -1070522,3,1,1,1,1,1,2,1,1,2 -1073960,10,10,10,10,6,10,8,1,5,4 -1076352,3,6,4,10,3,3,3,4,1,4 -1084139,6,3,2,1,3,4,4,1,1,4 -1115293,1,1,1,1,2,1,1,1,1,2 -1119189,5,8,9,4,3,10,7,1,1,4 -1133991,4,1,1,1,1,1,2,1,1,2 -1142706,5,10,10,10,6,10,6,5,2,4 -1155967,5,1,2,10,4,5,2,1,1,2 -1170945,3,1,1,1,1,1,2,1,1,2 -1181567,1,1,1,1,1,1,1,1,1,2 -1182404,4,2,1,1,2,1,1,1,1,2 -1204558,4,1,1,1,2,1,2,1,1,2 -1217952,4,1,1,1,2,1,2,1,1,2 -1224565,6,1,1,1,2,1,3,1,1,2 -1238186,4,1,1,1,2,1,2,1,1,2 -1253917,4,1,1,2,2,1,2,1,1,2 -1265899,4,1,1,1,2,1,3,1,1,2 -1268766,1,1,1,1,2,1,1,1,1,2 -1277268,3,3,1,1,2,1,1,1,1,2 -1286943,8,10,10,10,7,5,4,8,7,4 -1295508,1,1,1,1,2,4,1,1,1,2 -1297327,5,1,1,1,2,1,1,1,1,2 -1297522,2,1,1,1,2,1,1,1,1,2 -1298360,1,1,1,1,2,1,1,1,1,2 -1299924,5,1,1,1,2,1,2,1,1,2 -1299994,5,1,1,1,2,1,1,1,1,2 -1304595,3,1,1,1,1,1,2,1,1,2 -1306282,6,6,7,10,3,10,8,10,2,4 -1313325,4,10,4,7,3,10,9,10,1,4 -1320077,1,1,1,1,1,1,1,1,1,2 -1320077,1,1,1,1,1,1,2,1,1,2 -1320304,3,1,2,2,2,1,1,1,1,2 -1330439,4,7,8,3,4,10,9,1,1,4 -333093,1,1,1,1,3,1,1,1,1,2 -369565,4,1,1,1,3,1,1,1,1,2 -412300,10,4,5,4,3,5,7,3,1,4 -672113,7,5,6,10,4,10,5,3,1,4 -749653,3,1,1,1,2,1,2,1,1,2 -769612,3,1,1,2,2,1,1,1,1,2 -769612,4,1,1,1,2,1,1,1,1,2 -798429,4,1,1,1,2,1,3,1,1,2 -807657,6,1,3,2,2,1,1,1,1,2 -8233704,4,1,1,1,1,1,2,1,1,2 -837480,7,4,4,3,4,10,6,9,1,4 -867392,4,2,2,1,2,1,2,1,1,2 -869828,1,1,1,1,1,1,3,1,1,2 -1043068,3,1,1,1,2,1,2,1,1,2 -1056171,2,1,1,1,2,1,2,1,1,2 -1061990,1,1,3,2,2,1,3,1,1,2 -1113061,5,1,1,1,2,1,3,1,1,2 -1116192,5,1,2,1,2,1,3,1,1,2 -1135090,4,1,1,1,2,1,2,1,1,2 -1145420,6,1,1,1,2,1,2,1,1,2 -1158157,5,1,1,1,2,2,2,1,1,2 -1171578,3,1,1,1,2,1,1,1,1,2 -1174841,5,3,1,1,2,1,1,1,1,2 -1184586,4,1,1,1,2,1,2,1,1,2 -1186936,2,1,3,2,2,1,2,1,1,2 -1197527,5,1,1,1,2,1,2,1,1,2 -1222464,6,10,10,10,4,10,7,10,1,4 -1240603,2,1,1,1,1,1,1,1,1,2 -1240603,3,1,1,1,1,1,1,1,1,2 -1241035,7,8,3,7,4,5,7,8,2,4 -1287971,3,1,1,1,2,1,2,1,1,2 -1289391,1,1,1,1,2,1,3,1,1,2 -1299924,3,2,2,2,2,1,4,2,1,2 -1306339,4,4,2,1,2,5,2,1,2,2 -1313658,3,1,1,1,2,1,1,1,1,2 -1313982,4,3,1,1,2,1,4,8,1,2 -1321264,5,2,2,2,1,1,2,1,1,2 -1321321,5,1,1,3,2,1,1,1,1,2 -1321348,2,1,1,1,2,1,2,1,1,2 -1321931,5,1,1,1,2,1,2,1,1,2 -1321942,5,1,1,1,2,1,3,1,1,2 -1321942,5,1,1,1,2,1,3,1,1,2 -1328331,1,1,1,1,2,1,3,1,1,2 -1328755,3,1,1,1,2,1,2,1,1,2 -1331405,4,1,1,1,2,1,3,2,1,2 -1331412,5,7,10,10,5,10,10,10,1,4 -1333104,3,1,2,1,2,1,3,1,1,2 -1334071,4,1,1,1,2,3,2,1,1,2 -1343068,8,4,4,1,6,10,2,5,2,4 -1343374,10,10,8,10,6,5,10,3,1,4 -1344121,8,10,4,4,8,10,8,2,1,4 -142932,7,6,10,5,3,10,9,10,2,4 -183936,3,1,1,1,2,1,2,1,1,2 -324382,1,1,1,1,2,1,2,1,1,2 -378275,10,9,7,3,4,2,7,7,1,4 -385103,5,1,2,1,2,1,3,1,1,2 -690557,5,1,1,1,2,1,2,1,1,2 -695091,1,1,1,1,2,1,2,1,1,2 -695219,1,1,1,1,2,1,2,1,1,2 -824249,1,1,1,1,2,1,3,1,1,2 -871549,5,1,2,1,2,1,2,1,1,2 -878358,5,7,10,6,5,10,7,5,1,4 -1107684,6,10,5,5,4,10,6,10,1,4 -1115762,3,1,1,1,2,1,1,1,1,2 -1217717,5,1,1,6,3,1,1,1,1,2 -1239420,1,1,1,1,2,1,1,1,1,2 -1254538,8,10,10,10,6,10,10,10,1,4 -1261751,5,1,1,1,2,1,2,2,1,2 -1268275,9,8,8,9,6,3,4,1,1,4 -1272166,5,1,1,1,2,1,1,1,1,2 -1294261,4,10,8,5,4,1,10,1,1,4 -1295529,2,5,7,6,4,10,7,6,1,4 -1298484,10,3,4,5,3,10,4,1,1,4 -1311875,5,1,2,1,2,1,1,1,1,2 -1315506,4,8,6,3,4,10,7,1,1,4 -1320141,5,1,1,1,2,1,2,1,1,2 -1325309,4,1,2,1,2,1,2,1,1,2 -1333063,5,1,3,1,2,1,3,1,1,2 -1333495,3,1,1,1,2,1,2,1,1,2 -1334659,5,2,4,1,1,1,1,1,1,2 -1336798,3,1,1,1,2,1,2,1,1,2 -1344449,1,1,1,1,1,1,2,1,1,2 -1350568,4,1,1,1,2,1,2,1,1,2 -1352663,5,4,6,8,4,1,8,10,1,4 -188336,5,3,2,8,5,10,8,1,2,4 -352431,10,5,10,3,5,8,7,8,3,4 -353098,4,1,1,2,2,1,1,1,1,2 -411453,1,1,1,1,2,1,1,1,1,2 -557583,5,10,10,10,10,10,10,1,1,4 -636375,5,1,1,1,2,1,1,1,1,2 -736150,10,4,3,10,3,10,7,1,2,4 -803531,5,10,10,10,5,2,8,5,1,4 -822829,8,10,10,10,6,10,10,10,10,4 -1016634,2,3,1,1,2,1,2,1,1,2 -1031608,2,1,1,1,1,1,2,1,1,2 -1041043,4,1,3,1,2,1,2,1,1,2 -1042252,3,1,1,1,2,1,2,1,1,2 -1061990,4,1,1,1,2,1,2,1,1,2 -1073836,5,1,1,1,2,1,2,1,1,2 -1083817,3,1,1,1,2,1,2,1,1,2 -1096352,6,3,3,3,3,2,6,1,1,2 -1140597,7,1,2,3,2,1,2,1,1,2 -1149548,1,1,1,1,2,1,1,1,1,2 -1174009,5,1,1,2,1,1,2,1,1,2 -1183596,3,1,3,1,3,4,1,1,1,2 -1190386,4,6,6,5,7,6,7,7,3,4 -1190546,2,1,1,1,2,5,1,1,1,2 -1213273,2,1,1,1,2,1,1,1,1,2 -1218982,4,1,1,1,2,1,1,1,1,2 -1225382,6,2,3,1,2,1,1,1,1,2 -1235807,5,1,1,1,2,1,2,1,1,2 -1238777,1,1,1,1,2,1,1,1,1,2 -1253955,8,7,4,4,5,3,5,10,1,4 -1257366,3,1,1,1,2,1,1,1,1,2 -1260659,3,1,4,1,2,1,1,1,1,2 -1268952,10,10,7,8,7,1,10,10,3,4 -1275807,4,2,4,3,2,2,2,1,1,2 -1277792,4,1,1,1,2,1,1,1,1,2 -1277792,5,1,1,3,2,1,1,1,1,2 -1285722,4,1,1,3,2,1,1,1,1,2 -1288608,3,1,1,1,2,1,2,1,1,2 -1290203,3,1,1,1,2,1,2,1,1,2 -1294413,1,1,1,1,2,1,1,1,1,2 -1299596,2,1,1,1,2,1,1,1,1,2 -1303489,3,1,1,1,2,1,2,1,1,2 -1311033,1,2,2,1,2,1,1,1,1,2 -1311108,1,1,1,3,2,1,1,1,1,2 -1315807,5,10,10,10,10,2,10,10,10,4 -1318671,3,1,1,1,2,1,2,1,1,2 -1319609,3,1,1,2,3,4,1,1,1,2 -1323477,1,2,1,3,2,1,2,1,1,2 -1324572,5,1,1,1,2,1,2,2,1,2 -1324681,4,1,1,1,2,1,2,1,1,2 -1325159,3,1,1,1,2,1,3,1,1,2 -1326892,3,1,1,1,2,1,2,1,1,2 -1330361,5,1,1,1,2,1,2,1,1,2 -1333877,5,4,5,1,8,1,3,6,1,2 -1334015,7,8,8,7,3,10,7,2,3,4 -1334667,1,1,1,1,2,1,1,1,1,2 -1339781,1,1,1,1,2,1,2,1,1,2 -1339781,4,1,1,1,2,1,3,1,1,2 -13454352,1,1,3,1,2,1,2,1,1,2 -1345452,1,1,3,1,2,1,2,1,1,2 -1345593,3,1,1,3,2,1,2,1,1,2 -1347749,1,1,1,1,2,1,1,1,1,2 -1347943,5,2,2,2,2,1,1,1,2,2 -1348851,3,1,1,1,2,1,3,1,1,2 -1350319,5,7,4,1,6,1,7,10,3,4 -1350423,5,10,10,8,5,5,7,10,1,4 -1352848,3,10,7,8,5,8,7,4,1,4 -1353092,3,2,1,2,2,1,3,1,1,2 -1354840,2,1,1,1,2,1,3,1,1,2 -1354840,5,3,2,1,3,1,1,1,1,2 -1355260,1,1,1,1,2,1,2,1,1,2 -1365075,4,1,4,1,2,1,1,1,1,2 -1365328,1,1,2,1,2,1,2,1,1,2 -1368267,5,1,1,1,2,1,1,1,1,2 -1368273,1,1,1,1,2,1,1,1,1,2 -1368882,2,1,1,1,2,1,1,1,1,2 -1369821,10,10,10,10,5,10,10,10,7,4 -1371026,5,10,10,10,4,10,5,6,3,4 -1371920,5,1,1,1,2,1,3,2,1,2 -466906,1,1,1,1,2,1,1,1,1,2 -466906,1,1,1,1,2,1,1,1,1,2 -534555,1,1,1,1,2,1,1,1,1,2 -536708,1,1,1,1,2,1,1,1,1,2 -566346,3,1,1,1,2,1,2,3,1,2 -603148,4,1,1,1,2,1,1,1,1,2 -654546,1,1,1,1,2,1,1,1,8,2 -654546,1,1,1,3,2,1,1,1,1,2 -695091,5,10,10,5,4,5,4,4,1,4 -714039,3,1,1,1,2,1,1,1,1,2 -763235,3,1,1,1,2,1,2,1,2,2 -776715,3,1,1,1,3,2,1,1,1,2 -841769,2,1,1,1,2,1,1,1,1,2 -888820,5,10,10,3,7,3,8,10,2,4 -897471,4,8,6,4,3,4,10,6,1,4 -897471,4,8,8,5,4,5,10,4,1,4 diff --git a/breast-cancer-wisconsin.names b/breast-cancer-wisconsin.names deleted file mode 100644 index 54b59a12..00000000 --- a/breast-cancer-wisconsin.names +++ /dev/null @@ -1,126 +0,0 @@ -Citation Request: - This breast cancer databases was obtained from the University of Wisconsin - Hospitals, Madison from Dr. William H. Wolberg. If you publish results - when using this database, then please include this information in your - acknowledgements. Also, please cite one or more of: - - 1. O. L. Mangasarian and W. H. Wolberg: "Cancer diagnosis via linear - programming", SIAM News, Volume 23, Number 5, September 1990, pp 1 & 18. - - 2. William H. Wolberg and O.L. Mangasarian: "Multisurface method of - pattern separation for medical diagnosis applied to breast cytology", - Proceedings of the National Academy of Sciences, U.S.A., Volume 87, - December 1990, pp 9193-9196. - - 3. O. L. Mangasarian, R. Setiono, and W.H. Wolberg: "Pattern recognition - via linear programming: Theory and application to medical diagnosis", - in: "Large-scale numerical optimization", Thomas F. Coleman and Yuying - Li, editors, SIAM Publications, Philadelphia 1990, pp 22-30. - - 4. K. P. Bennett & O. L. Mangasarian: "Robust linear programming - discrimination of two linearly inseparable sets", Optimization Methods - and Software 1, 1992, 23-34 (Gordon & Breach Science Publishers). - -1. Title: Wisconsin Breast Cancer Database (January 8, 1991) - -2. Sources: - -- Dr. WIlliam H. Wolberg (physician) - University of Wisconsin Hospitals - Madison, Wisconsin - USA - -- Donor: Olvi Mangasarian (mangasarian@cs.wisc.edu) - Received by David W. Aha (aha@cs.jhu.edu) - -- Date: 15 July 1992 - -3. Past Usage: - - Attributes 2 through 10 have been used to represent instances. - Each instance has one of 2 possible classes: benign or malignant. - - 1. Wolberg,~W.~H., \& Mangasarian,~O.~L. (1990). Multisurface method of - pattern separation for medical diagnosis applied to breast cytology. In - {\it Proceedings of the National Academy of Sciences}, {\it 87}, - 9193--9196. - -- Size of data set: only 369 instances (at that point in time) - -- Collected classification results: 1 trial only - -- Two pairs of parallel hyperplanes were found to be consistent with - 50% of the data - -- Accuracy on remaining 50% of dataset: 93.5% - -- Three pairs of parallel hyperplanes were found to be consistent with - 67% of data - -- Accuracy on remaining 33% of dataset: 95.9% - - 2. Zhang,~J. (1992). Selecting typical instances in instance-based - learning. In {\it Proceedings of the Ninth International Machine - Learning Conference} (pp. 470--479). Aberdeen, Scotland: Morgan - Kaufmann. - -- Size of data set: only 369 instances (at that point in time) - -- Applied 4 instance-based learning algorithms - -- Collected classification results averaged over 10 trials - -- Best accuracy result: - -- 1-nearest neighbor: 93.7% - -- trained on 200 instances, tested on the other 169 - -- Also of interest: - -- Using only typical instances: 92.2% (storing only 23.1 instances) - -- trained on 200 instances, tested on the other 169 - -4. Relevant Information: - - Samples arrive periodically as Dr. Wolberg reports his clinical cases. - The database therefore reflects this chronological grouping of the data. - This grouping information appears immediately below, having been removed - from the data itself: - - Group 1: 367 instances (January 1989) - Group 2: 70 instances (October 1989) - Group 3: 31 instances (February 1990) - Group 4: 17 instances (April 1990) - Group 5: 48 instances (August 1990) - Group 6: 49 instances (Updated January 1991) - Group 7: 31 instances (June 1991) - Group 8: 86 instances (November 1991) - ----------------------------------------- - Total: 699 points (as of the donated datbase on 15 July 1992) - - Note that the results summarized above in Past Usage refer to a dataset - of size 369, while Group 1 has only 367 instances. This is because it - originally contained 369 instances; 2 were removed. The following - statements summarizes changes to the original Group 1's set of data: - - ##### Group 1 : 367 points: 200B 167M (January 1989) - ##### Revised Jan 10, 1991: Replaced zero bare nuclei in 1080185 & 1187805 - ##### Revised Nov 22,1991: Removed 765878,4,5,9,7,10,10,10,3,8,1 no record - ##### : Removed 484201,2,7,8,8,4,3,10,3,4,1 zero epithelial - ##### : Changed 0 to 1 in field 6 of sample 1219406 - ##### : Changed 0 to 1 in field 8 of following sample: - ##### : 1182404,2,3,1,1,1,2,0,1,1,1 - -5. Number of Instances: 699 (as of 15 July 1992) - -6. Number of Attributes: 10 plus the class attribute - -7. Attribute Information: (class attribute has been moved to last column) - - # Attribute Domain - -- ----------------------------------------- - 1. Sample code number id number - 2. Clump Thickness 1 - 10 - 3. Uniformity of Cell Size 1 - 10 - 4. Uniformity of Cell Shape 1 - 10 - 5. Marginal Adhesion 1 - 10 - 6. Single Epithelial Cell Size 1 - 10 - 7. Bare Nuclei 1 - 10 - 8. Bland Chromatin 1 - 10 - 9. Normal Nucleoli 1 - 10 - 10. Mitoses 1 - 10 - 11. Class: (2 for benign, 4 for malignant) - -8. Missing attribute values: 16 - - There are 16 instances in Groups 1 to 6 that contain a single missing - (i.e., unavailable) attribute value, now denoted by "?". - -9. Class distribution: - - Benign: 458 (65.5%) - Malignant: 241 (34.5%) diff --git a/coq/API.v b/coq/API.v deleted file mode 100644 index ac550a47..00000000 --- a/coq/API.v +++ /dev/null @@ -1,58 +0,0 @@ -Require Import FloatishIEEE. -Require Import ExtrFloatishIEEE. - - -(* Require Import ExtrR. *) -(* Our stuff modules *) - -Require Import Utils. -Require Import Vector. -Require Gen_NN. -Require Import DefinedFunctions. -Require Import FloatishDef. -Require Import BinInt. -Require Import String. -Require Import Streams. -Local Open Scope list. - -Existing Instance floatish_IEEE. - -Example test := - mk_env_entry (Name "f", DTfloat) (FfromZ 1)%Z :: - mk_env_entry (Name "v", DTVector 3) (ConstVector 3 ((FfromZ (-2)))%Z) :: - mk_env_entry (Name "m", DTMatrix 2 3) (ConstMatrix 2 3 (FfromZ 3))%Z :: nil. -Module API. - Example opt := @Gen_NN.opt floatish_IEEE. - Example opt2 := @Gen_NN.opt2 floatish_IEEE. - Example test_update := @Gen_NN.test_update floatish_IEEE. - Example testopt := @Gen_NN.testopt floatish_IEEE. - Example testreeopt := @Gen_NN.testreeopt floatish_IEEE. - Example gradenv := @Gen_NN.gradenv floatish_IEEE. - Example gradenv_tree := @Gen_NN.gradenv_tree floatish_IEEE. - Example test_env := test. - - Example discard_first {A} (l:list (list A)) : list (list A) := List.map (@List.tl A) l. - Definition normalizeIntData := Gen_NN.normalizeIntData. - Definition init_env2 := Gen_NN.init_env2. - CoFixpoint mkIndexedStream {A} (i : nat) (ran : nat -> A) : Stream A := - Cons (ran i) (mkIndexedStream (S i) ran). - Definition streamtake := Gen_NN.streamtake. - Definition df_env := DefinedFunctions.df_env. - Definition eval_wisconsin_batch (nsamp:nat) - (env:df_env) (data : Matrix float nsamp 10) : list float := - match Gen_NN.eval_wisconsin_batch nsamp env data with - | Some val => val :: nil - | _ => nil - end. - - Definition wisconsin_test := Gen_NN.wisconsin_test. - Definition wisconsin_test_env := Gen_NN.wisconsin_test_env. - Definition wisconsin_gradenv_tree := Gen_NN.wisconsin_gradenv_tree. - Definition wisconsin_gradenv := Gen_NN.wisconsin_gradenv. - Definition nn_test := Gen_NN.NN_test. - Definition nn_test_val := Gen_NN.NN_test_val. - Definition nn_test_env := Gen_NN.NN_test_env. - Definition nn_test_gradenv_tree := Gen_NN.NN_test_gradenv_tree. - Definition nn_test_gradenv := Gen_NN.NN_test_gradenv. - - End API. diff --git a/coq/FHE/arith.v b/coq/FHE/arith.v deleted file mode 100644 index 913793d2..00000000 --- a/coq/FHE/arith.v +++ /dev/null @@ -1,3455 +0,0 @@ -Require Import Reals Lra Lia List Permutation. -From mathcomp Require Import common ssreflect fintype bigop ssrnat matrix Rstruct complex seq fingroup. -From mathcomp Require Import ssralg ssrfun. -From mathcomp Require Import generic_quotient ring_quotient. -From mathcomp Require Import poly mxpoly polydiv ssrint zmodp eqtype ssrbool div order. -From mathcomp Require Import ring ssrZ. - -Import ssralg.GRing. -Import ssrnum.Num.Theory. -Require Import nth_root encode zp_prim_root. - -Ltac coq_lra := lra. -From mathcomp Require Import lra. - -Set Bullet Behavior "Strict Subproofs". - - -Local Open Scope ring_scope. - - -Record ENCODE : Type := - mkENCODE - { clear : {poly int} ; - scale : nat - }. - -Record FHE : Type := - mkFHE - { q : nat; - cypher : {poly {poly 'Z_q}} ; - norm_bound : R ; - noise_bound : R - }. - -Definition FHE_add {q : nat} (P Q : {poly {poly 'Z_q}} ) := P + Q. - -Definition FHE_mult_base {q : nat} (P Q : {poly {poly 'Z_q}} ) := P * Q. - -Definition zliftc {q : nat} (c : 'Z_q) : int := - if (c <= q/2) then c%:Z else c%:Z - q%:Z. - -Lemma zliftc_bound {q : nat} (c : 'Z_q) : - 1 < q -> - `| zliftc c | <= q/2. -Proof. - intros. - rewrite /zliftc. - case: (boolP (c <= q/2)); intros. - - by destruct c. - - destruct c. - simpl in *. - move /negP in i. - assert (m > (Nat.divmod q 1 0 1).1) by lia. - rewrite Zp_cast // in i0. - replace `|m-q| with (q - m)%N by (rewrite distnEr; lia). - assert (q <= 2*(Nat.divmod q 1 0 1).1 + 1). - { - move: (Nat.divmod_spec q 1 0 1 (le_refl _)) => /=. - case: (Nat.divmod q 1 0 1) => /= x u. - rewrite !Nat.add_0_r. - move=> [-> ubound]. - destruct u; lia. - } - lia. - Qed. - -Lemma modp_small (q : nat) (m : nat) : - m < (Zp_trunc q).+2 -> - nat_of_ord (intmul (1 : 'Z_q) (Posz m)) = m. -Proof. - rewrite /intmul Zp_nat /=. - by apply/modn_small. -Qed. - -Lemma modpp (q : nat) : - nat_of_ord (intmul (1 : 'Z_q) (Posz (Zp_trunc q).+2)) = 0%nat. -Proof. - by rewrite /intmul Zp_nat /= modnn. -Qed. - -Lemma modpp' (q : nat) : - 1 < q -> - intmul (1 : 'Z_q) (Posz q) = 0. -Proof. - intros. - apply ord_inj. - by rewrite /intmul Zp_nat /= Zp_cast // modnn. -Qed. - -Lemma zliftc_valid {q : nat} (c : 'Z_q) : - 1 < q -> - (zliftc c) %:~R = c. -Proof. - intros. - rewrite /zliftc. - case: (c <= q/2). - - destruct c. - apply ord_inj => /=. - by rewrite modp_small. - - destruct c. - rewrite intrD mulrNz modpp' // oppr0 addr0. - apply ord_inj => /=. - by rewrite modp_small. -Qed. - -Lemma zliftc_add2 {q : nat} (a b : 'Z_q) : - 1 < q -> - `|zliftc (a + b)%R - ((zliftc a) + (zliftc b))%R | <= q. -Proof. - have diveq: (Nat.divmod q 1 0 1).1 = (q / 2)%nat by []. - - move=> qbig. - rewrite /zliftc /=. - Ltac t1 C := - match goal with - | [|- is_true (leq (absz ?x) _) ] => - have ->: x = C by lia - end. - - Ltac t2 C := t1 C - ; case: (boolP (((Zp_trunc q).+1 < _ + _))) => // _ - ; (rewrite ?mul0n ?mul1n ?Zp_cast // ?distnn); lia. - - case: (boolP ( (a + b) %% (Zp_trunc q).+2 <= (Nat.divmod q 1 0 1).1)) => ltab - ; case: (boolP (a <= (Nat.divmod q 1 0 1).1)) => lta - ; case: (boolP (b <= (Nat.divmod q 1 0 1).1)) => ltb - ; rewrite modnD // [modn a _]modn_small ?[modn b _]modn_small; try apply ltn_ord. - - t2 (opp (Posz (muln (leq (S (S (Zp_trunc q))) (addn a b)) (S (S (Zp_trunc q)))))). - - t2 (add (V:=int_ZmodType) q (opp (Posz (muln (leq (S (S (Zp_trunc q))) (addn a b)) (S (S (Zp_trunc q))))))). - - t2 (add (V:=int_ZmodType) (Posz q) (opp (Posz (muln (leq (S (S (Zp_trunc q))) (addn a b)) (S (S (Zp_trunc q))))))). - - have eqq: ((Zp_trunc q).+1 < a + b). - { move: (Nat.divmod_spec q 1 0 1) lta ltb. - case: (Nat.divmod q 1 0 1) => /= x u. - move/(_ (le_n _)) => [eqq1 eqq2] lta ltb. - rewrite !Nat.add_0_r in eqq1. - rewrite {1}Zp_cast //. - destruct u; simpl in *; lia. - } - rewrite eqq mul1n. - rewrite {1}Zp_cast // in eqq. - rewrite {3}Zp_cast //. - by t1 (Posz q). - - case: (boolP ((Zp_trunc q).+1 < a + b)). - + rewrite mul1n {1}Zp_cast // => ineq2. - simpl in lta. - rewrite diveq in lta ltb ltab. - case: (eqVneq q (nat_of_ord a + nat_of_ord b))%nat =>eqq2. - * case/negP: ltab. - by rewrite {3}Zp_cast // {3}eqq2 modnn. - * suff: 2 * (q / 2)%nat <= q by lia. - apply /leP. - by apply Nat.mul_div_le. - + rewrite mul0n => _. - t1 (opp (Posz q)). - by rewrite abszN. - - t2 (opp (Posz (muln (leq (S (S (Zp_trunc q))) (addn a b)) (S (S (Zp_trunc q)))))). - - t2 (opp (Posz (muln (leq (S (S (Zp_trunc q))) (addn a b)) (S (S (Zp_trunc q)))))). - - have eqq: ((Zp_trunc q).+1 < a + b). - { move: (Nat.divmod_spec q 1 0 1) lta ltb. - case: (Nat.divmod q 1 0 1) => /= x u. - move/(_ (le_n _)) => [eqq1 eqq2] lta ltb. - rewrite !Nat.add_0_r in eqq1. - rewrite {1}Zp_cast //. - destruct u; simpl in *; lia. - } - rewrite eqq mul1n. - rewrite {1}Zp_cast // in eqq. - rewrite {3}Zp_cast //; lia. -Qed. - - -Lemma bounded_dvdn_cases (a q : nat) : - 1 < q -> - (q %| a)%N -> - a < 2 * q -> - (a == q) || (a == 0%nat). -Proof. - move=> qbig. - move/dvdnP => [k ->]. - rewrite ltn_mul2r. - move/andP=>[_ -]. - case: k; [| case]; lia. -Qed. - - -Lemma bounded_dvdn (a q : nat) : - 1 < q -> - (q %| a)%N -> - a < 2 * q -> - { c : nat | - a = (c * q)%N /\ c <= 1}. -Proof. - move=> qbig qdiva asmall. - move: (bounded_dvdn_cases a q qbig qdiva asmall). - case: (eqVneq a q) => /=. - - exists 1%nat. lia. - - exists 0%nat. lia. -Qed. - -Lemma bounded_divi (a : int) (q : nat) : - 1 < q -> - (q %| `|a|)%N -> - `|a| < 2 * q -> - { c : nat | - `| a | = (c * q)%N /\ c <= 1}. -Proof. - by apply bounded_dvdn. -Qed. - -Lemma bounded_divn_alt (a : nat) (q k : nat) : - 1 < q -> - (q %| a)%N -> - a <= k * q -> - { c : nat | a = (c * q)%N /\ c <= k}. -Proof. - move=> qbig. - exists (a %/ q)%nat. - by rewrite divnK // leq_divLR. -Qed. - -Lemma bounded_divi_alt (a : int) (q k : nat) : - 1 < q -> - (q %| `|a|)%N -> - `|a| <= k * q -> - { c : nat | `| a | = (c * q)%N /\ c <= k}. -Proof. - apply bounded_divn_alt. -Qed. - -Lemma absz_triang (a b : int) : - `|a + b| <= `|a| + `|b|. -Proof. - lia. -Qed. - -Lemma Zp_int (p:nat) (a : int) (pbig:1 inZp n - | Negz n => inZp (p - modn (n.+1) p) - end. -Proof. - destruct a. - - by rewrite /intmul Zp_nat. - - rewrite /intmul. - rewrite /intmul Zp_nat. - rewrite /opp /= /Zp_opp /=. - apply ord_inj => /=. - by rewrite Zp_cast //. -Qed. - -Lemma Z_q_0_dvd_abs {q : nat} (a : int) : - 1 < q -> - a %:~R = (0 : 'Z_q) -> - (q %| `|a|)%N. -Proof. - move => qbig. - rewrite Zp_int //. - case: a => n /=. - - move/(f_equal val) => /=. - rewrite Zp_cast //. - by move/eqP. - - move/(f_equal val) => /=. - rewrite Zp_cast //. - rewrite modnB; try lia. - rewrite modnn modn_mod. - lia. - Qed. - -Lemma zliftc_add2_ex {q : nat} (a b : 'Z_q) : - 1 < q -> - { c : nat | - `|zliftc (a + b)%R - ((zliftc a) + (zliftc b))%R | = (c * q)%N /\ - c <= 1}. -Proof. - intros. - apply bounded_divi; trivial. - - apply Z_q_0_dvd_abs; trivial. - rewrite rmorphD rmorphN !rmorphD /=. - rewrite !zliftc_valid //. - ring. - - move: (zliftc_bound a H) => a_bound. - move: (zliftc_bound b H) => b_bound. - move: (zliftc_bound (a + b) H) => ab_bound. - move: (absz_triang (zliftc a) (zliftc b)) => triang_ab. - move: (absz_triang (zliftc (a+b)) (- (zliftc a + zliftc b))) => triang_abc. - rewrite -abszN in triang_ab. - assert (q/2 + q/2 + q/2 < 2 * q)%N. - { - move: (Nat.divmod_spec q 1 0 1 (le_refl _)) => /=. - case: Nat.divmod => /= x y. - rewrite !Nat.add_0_r. - move=> [qeq ubound]; subst. - destruct y; lia. - } - lia. -Qed. - -Lemma zliftc_mul2 {q : nat} (a b : 'Z_q) : - 1 < q -> - `|(zliftc (a * b) - (zliftc a) * (zliftc b))%R | <= (q/2 * (1 + q/2))%N. -Proof. - intros. - move: (zliftc_bound a H) => a_bound. - move: (zliftc_bound b H) => b_bound. - move: (zliftc_bound (a * b) H) => ab_bound. - move: (absz_triang (zliftc (a*b)) (- (zliftc a * zliftc b))) => triang. - rewrite abszN in triang. - assert (`|zliftc a * zliftc b| <= (q/2) * (q/2)). - { - rewrite abszM. - by apply leq_mul. - } - lia. - Qed. - -Lemma zliftc_mul2_ex {q : nat} (a b : 'Z_q) : - 1 < q -> - { c : nat | - `|zliftc (a * b)%R - ((zliftc a) * (zliftc b))%R | = (c * q)%N /\ - c <= q/2}. -Proof. - intros. - apply bounded_divi_alt; trivial. - - apply Z_q_0_dvd_abs; trivial. - rewrite rmorphD rmorphN rmorphM /=. - rewrite !zliftc_valid //. - ring. - - generalize (zliftc_mul2 a b H); intros. - assert ((q / 2) * (1 + q / 2) <= (q/2) * q). - { - rewrite -Nat.div2_div leq_pmul2l. - - generalize (Nat.lt_div2 q); lia. - - generalize (div2_not_R0 q); lia. - } - lia. -Qed. - -Definition zlift {q : nat} (a : {poly 'Z_q}) : {poly int} := - map_poly zliftc a. - -Lemma zliftc0 (q : nat) : - zliftc (0 : 'Z_q) = 0. -Proof. - by rewrite /zliftc. -Qed. - -Lemma zlift0 {q : nat} (a : {poly 'Z_q}) : - a = 0 -> - zlift a = 0. -Proof. - intros. - rewrite /zlift /zliftc H. - apply map_poly0. -Qed. - -Lemma zlift0_alt {q : nat} : - zlift (0 : {poly 'Z_q}) = 0. -Proof. - rewrite /zlift /zliftc. - apply map_poly0. -Qed. - - -Definition icoef_maxnorm_ord (p : {poly int}):nat := \max_(j < seq.size p) `|p`_ j|. -Definition icoef_maxnorm_nat (p : {poly int}):nat := \max_(0 <= j < seq.size p) `|p`_ j|. -Definition icoef_maxnorm (p : {poly int}):nat := \max_(j < seq.size p) `|p`_ j|. - -Lemma icoef_maxnorm_conv (p : {poly int}) : - \max_(j < seq.size p) `|p`_ j| = \max_(0 <= j < seq.size p) `|p`_ j|. -Proof. - by rewrite big_mkord. -Qed. - -Lemma zlift_add2_ex {q : nat} (a b : {poly 'Z_q}) : - (1 < q) -> - { c : {poly int} | - zlift (a + b) = zlift a + zlift b + c /\ - icoef_maxnorm c <= q}. -Proof. - exists (zlift (a + b) - (zlift a + zlift b)). - split. - - ring. - - rewrite /icoef_maxnorm /zlift. - apply /bigmax_leqP => i _. - rewrite !(coefD,coefN). - rewrite !coef_map_id0; try apply zliftc0. - rewrite !coefD; try apply zlift0_alt. - by apply zliftc_add2. -Qed. - -Definition zlift_add2_perturb {q : nat} (qbig:(1 < q)) (a b : {poly 'Z_q}) : {poly int} - := sval (zlift_add2_ex a b qbig). - -Lemma zlift_add2_eq {q : nat} (qbig:(1 < q)) (a b : {poly 'Z_q}) - : zlift (a + b) = zlift a + zlift b + zlift_add2_perturb qbig a b. -Proof. - rewrite /zlift_add2_perturb. - case: (zlift_add2_perturb qbig a b). - case: zlift_add2_ex; intros; simpl. - tauto. -Qed. - -Lemma zlift_add2_perturb_small {q : nat} (qbig:(1 < q)) (a b : {poly 'Z_q}) - : icoef_maxnorm (zlift_add2_perturb qbig a b) <= q. -Proof. - rewrite /zlift_add2_perturb. - case: (zlift_add2_perturb qbig a b). - case: zlift_add2_ex; intros; simpl. - tauto. -Qed. - -Definition upi (c:R) : int := int_of_Z (up c). - -(* 0 <= rand < 1 *) -Definition ran_round (x rand : R) : int := - let hi := upi x in - if (hi%:~R - x < rand)%O then hi else (hi - 1). - -Definition nearest_round (x : R) : int := ran_round x (1/2)%R. - -Definition nearest_round_sgn (x : R) : int := - if (x < 0)%O then - (nearest_round (-x)) else nearest_round x. - -Definition nearest_round_int (n d : int) : int := nearest_round ((n %:~R)/(d %:~R))%R. -Definition nearest_round_int_sgn (n d : int) : int := nearest_round_sgn ((n %:~R)/(d %:~R))%R. - -Lemma inv_nat_pos (d : nat) : - d != 0%N -> - (inv d%:~R > (0 : R))%O. -Proof. - rewrite invr_gt0. - replace (zero (ssrnum.Num.NumDomain.porder_zmodType R_numDomainType)) with - (((0 : int)%:~R : R)) by lra. - rewrite ltr_int; lia. -Qed. - -Lemma div_nat_pos (n d : nat) : - d != 0%N -> - ((0 : R) <= n%:~R / d %:~R)%O. -Proof. - intros. - generalize (inv_nat_pos d H); intros. - generalize (ler_pM2l H0 (0 : R) (n %:~R)); intros. - rewrite mulr0 mulrC in H1. - rewrite H1. - replace (0 : R) with (((0 : int)%:~R : R) ) by lra. - rewrite ler_int; lia. -Qed. - -Lemma div_int_pos (d : nat) (n : int): - d != 0%N -> - ((0 : int) <= n)%O -> - ((0 : R) <= n%:~R / d %:~R)%O. -Proof. - intros. - destruct n. - - by apply div_nat_pos. - - lia. -Qed. - -Lemma div_negz_neg (n d : nat) : - d != 0%N -> - (((Negz n)%:~R : R)/ d%:~R < 0)%O. -Proof. - intros. - generalize (inv_nat_pos d H); intros. - generalize (ltr_pM2l H0 ((Negz n) %:~R) (0 : R)); intros. - rewrite mulr0 mulrC in H1. - rewrite H1. - replace (0 : R) with (((0 : int)%:~R : R) ) by lra. - rewrite ltr_int; lia. -Qed. - -Lemma div_int_neg (d : nat) (n : int): - d != 0%N -> - ((0 : int) > n)%O -> - ((0 : R) > ((n%:~R / d %:~R) : R))%O. -Proof. - intros. - destruct n. - - lia. - - by apply div_negz_neg. -Qed. - -Lemma div_int_neg' (d : nat) (n : int): - d != 0%N -> - ((0 : int) >= n)%O -> - ((0 : R) >= ((n%:~R / d %:~R) : R))%O. -Proof. - intros. - destruct n. - - assert (n = 0%N) by lia. - rewrite H1 mul0r. - by rewrite Order.POrderTheory.lexx. - - generalize (div_negz_neg n d H); intros. - lra. -Qed. - - -Lemma IZRE (n : Z) : IZR n = (int_of_Z n)%:~R. -Proof. - destruct n. - - by rewrite /intmul /= /IZR R00. - - by rewrite /IZR -INR_IPR INRE /int_of_Z /intmul. - - rewrite /IZR -INR_IPR INRE /int_of_Z /intmul /opp /=. - f_equal; f_equal. - lia. -Qed. - -Lemma IZREb (n : int) : n%:~R = IZR (ssrZ.Z_of_int n). -Proof. - by rewrite -{1}(Z_of_intK n) -IZRE. -Qed. - -Lemma up_int_add (n : Z) (c : R) : - up (Rplus (IZR n) c) = Zplus n (up c). -Proof. - symmetry. - destruct (archimed c). - apply tech_up; rewrite plus_IZR; coq_lra. -Qed. - -Lemma upi_intl (n : int) (c : R) : - upi ((n%:~R) + c) = n + upi c. -Proof. - rewrite /upi !IZREb up_int_add. - lia. -Qed. - -Lemma upi0 : - upi 0 = 1. -Proof. - rewrite /upi -(tech_up 0 1); try coq_lra. - lia. -Qed. - -Lemma nearest_round_int0 (d : int) : - nearest_round_int 0 d = 0. -Proof. - rewrite /nearest_round_int /nearest_round /ran_round. - rewrite mul0r upi0 oppr0 addr0 addrN. - rewrite /intmul /=. - rewrite ltr_pdivlMr; last by lra. - by rewrite mul1r /natmul/= gtrDl ltr10. -Qed. - -Lemma nearest_round_int_add (n1 : int) (c : R) : - nearest_round (n1 %:~R + c)%R = n1 + nearest_round c. -Proof. - rewrite /nearest_round /ran_round. - have ->: (upi (n1%:~R + c))%:~R - (n1%:~R + c)%R = ((upi c)%:~R - c). - { - rewrite upi_intl; lra. - } - case: Order.TotalTheory.ltP => _ ; by rewrite upi_intl; lra. -Qed. - -Lemma nearest_round_int_mul_add (n1 n2 d : int) : - d <> 0 -> - nearest_round_int (n1 * d + n2) d = n1 + nearest_round_int n2 d. -Proof. - intros. - rewrite /nearest_round_int -nearest_round_int_add. - rewrite intrD intrM mulrDl. - rewrite -mulrA divff. - - f_equal; lra. - - rewrite intr_eq0. - by apply/eqP. -Qed. - -Lemma nearest_round_int_mul_add_r (n1 n2 d : int) : - d <> 0 -> - nearest_round_int (n1 + d * n2) d = nearest_round_int n1 d + n2. -Proof. - intros. - rewrite mulrC addrC nearest_round_int_mul_add //. - lia. -Qed. - -Lemma nearest_round_int_mod (n1 n2 p1 p2 : int) : - p1 <> 0 -> - intdiv.modz n1 (p1 * p2) = intdiv.modz n2 (p1 * p2) -> - intdiv.modz (nearest_round_int n1 p1) p2 = intdiv.modz (nearest_round_int n2 p1) p2. -Proof. - intros. - assert (exists (h : int), - n1 = n2 + h * (p1 * p2)). - { - exists (intdiv.divz n1 (p1 * p2) - intdiv.divz n2 (p1 * p2)). - generalize (intdiv.divz_eq n1 (p1 * p2)); intros. - generalize (intdiv.divz_eq n2 (p1 * p2)); intros. - rewrite {1}H2 {1}H1 H0. - ring. - } - destruct H1. - rewrite H1 addrC (mulrC p1 p2) mulrA nearest_round_int_mul_add //. - by rewrite intdiv.modzMDl. -Qed. - -Lemma up_add2' (r1 r2 : R) : - (up(r1 + r2) = Z.add (up r1) (up r2) \/ up(r1 + r2) = Z.sub (Z.add (up r1) (up r2)) 1)%Z. -Proof. - destruct (archimed r1). - destruct (archimed r2). - destruct (archimed (r1 + r2)). - case: (boolP (Rleb ((IZR (up r1) + IZR (up r2)) - (r1 + r2)) 1)). - - intros. - left. - move /RlebP in p. - by rewrite (tech_up (r1 + r2) (up r1 + up r2)) // plus_IZR; coq_lra. - - intros. - right. - move /RlebP in i. - by rewrite (tech_up (r1 + r2) (up r1 + up r2 - 1)) // !plus_IZR opp_IZR; coq_lra. -Qed. - -Lemma up_add2 (r1 r2 : R) : - Z.abs_nat (Z.sub (up (r1 + r2)) (Z.add (up r1) (up r2))) <= 1. -Proof. - destruct (up_add2' r1 r2); rewrite H; lia. -Qed. - -Lemma up_le (r1 r2 : R) : - Rle r1 r2 -> - Z.le (up r1) (up r2). -Proof. - case: (boolP (Rleb (IZR (up r1)) (IZR(up r2)))). - - intros. - apply le_IZR. - by move /RlebP in p. - - intros. - move /RlebP in i. - assert (Rlt (IZR (up r2)) (IZR (up r1))). - { - coq_lra. - } - apply lt_IZR in H0. - assert (Z.le (Z.add (up r2) (1 : Z)) (up r1)). - { - lia. - } - apply IZR_le in H1. - rewrite /one /= plus_IZR in H1. - destruct (archimed r1). - destruct (archimed r2). - coq_lra. -Qed. - -Lemma upi_le (r1 r2 : R) : - (r1 <= r2)%O -> - (upi r1 <= upi r2)%O. -Proof. - intros. - move /RlebP in H. - apply up_le in H. - rewrite /upi. - lia. -Qed. - -Lemma int_of_Z_abs x : `|int_of_Z x| = Z.abs_nat x. -Proof. - rewrite /int_of_Z /absz /Z.abs_nat. - destruct x; trivial. - lia. -Qed. - -Lemma upi_add2' (n1 n2 : R) : - (upi (n1 + n2) = (upi n1 + upi n2))%R \/ - (upi (n1 + n2) = (upi n1 + upi n2) -1)%R. -Proof. - rewrite /upi. - destruct (up_add2' n1 n2). - - left. - by rewrite H raddfD /=. - - right. - rewrite H. - rewrite -Z.add_opp_r !raddfD /=. - lra. -Qed. - -Lemma upi_add2 (n1 n2 : R) : - `|upi (n1 + n2) - (upi n1 + upi n2)%R| <= 1. -Proof. - rewrite /upi. - rewrite (_: - int_of_Z (up (n1 + n2)) - (int_of_Z (up n1) + int_of_Z (up n2))%R = - int_of_Z (Z.sub (up (n1 + n2)) (Z.add (up n1) (up n2)))%Z). - - rewrite int_of_Z_abs. - apply up_add2. - - rewrite -Z.add_opp_r !raddfD /= raddfN /= raddfD /=. - lra. -Qed. - -Lemma ran_round_add2 (n1 n2 cutoff : R) : - ((0 : R) < cutoff)%O -> - (cutoff < (1 : R))%O -> - let sum := ran_round n1 cutoff + ran_round n2 cutoff in - `|ran_round (n1 + n2)%R cutoff - sum| <= 1. -Proof. - move=> cutoff_big cutoff_small. - rewrite /ran_round. - case: Order.TotalTheory.ltP=>lt1 ; case: Order.TotalTheory.ltP => lt2 ; case: Order.TotalTheory.ltP => lt3 - ; try (destruct (upi_add2' n1 n2); rewrite H; try lia; rewrite H raddfD /= in lt1; lra). -Qed. - -Lemma nearest_round_add2 (n1 n2 : R) : - let sum := nearest_round n1 + nearest_round n2 in - `|nearest_round (n1 + n2) - sum| <= 1. -Proof. - rewrite /nearest_round. - apply ran_round_add2; lra. -Qed. - -Lemma nearest_round_int_add2 (n1 n2 d : int) : - d <> 0 -> - let sum := nearest_round_int n1 d + nearest_round_int n2 d in - `|nearest_round_int (n1 + n2) d - sum| <= 1. -Proof. - move=> dn0. - rewrite /= /nearest_round_int. - rewrite (_:((n1 + n2)%:~R / d%:~R)%R = ((n1%:~R / d%:~R) + (n2%:~R / d%:~R))%R). - - apply nearest_round_add2. - - ring. -Qed. - -Lemma upi_bound (r : R) : - let diff := (upi r)%:~R - r in - Rlt 0%R diff /\ Rle diff 1%R. -Proof. - rewrite /upi -IZRE. - apply for_base_fp. -Qed. - -Lemma upi_bound_O (r : R) : - let diff := (upi r)%:~R - r in - ((0 : R) < diff)%O /\ (diff <= 1%R)%O. -Proof. - destruct (upi_bound r). - move /RltP in H. - by move /RleP in H0. -Qed. - -Lemma int_to_R_lt : - {mono (intmul (1 : R)) : x y / (x < y)%O}. -Proof. - move=> x y. - apply ltr_int. -Qed. - -Lemma int_to_R_le : - {mono (intr: int -> R) : x y / (x <= y)%O}. -Proof. - move => x y. - apply ler_int. -Qed. - -Lemma upi_nat_mul_bound_R (r : R) (n : nat) : - (((upi r)%:~R *+n - r*+n) <= n%:R)%O. -Proof. - destruct (upi_bound_O r). - assert ((0 : R) <= n%:R)%O. - { - by rewrite (ler_int _ 0 n). - } - apply (ler_wpM2r H1) in H0. - rewrite mul1r mulrBl in H0. - by replace ((upi r)%:~R *+ n - r *+ n ) with - ((upi r)%:~R * n%:R - r * n%:R) by ring. -Qed. - -Lemma Rabs_left_O (r : R) : - (r <= 0)%O -> Rabs r = -r. -Proof. - intros. - move /RleP in H. - rewrite /opp /=. - rewrite /zero /= in H. - case: (boolP (r == 0)); intros. - - rewrite (eqP p) Rabs_R0 /zero /=. - coq_lra. - - move /eqP in i. - apply Rabs_left. - rewrite /zero /= in i. - coq_lra. -Qed. - -Lemma nearest_round_bound_O (r : R) : - let diff := (nearest_round r)%:~R - r in - (Rabs diff <= 1/2)%O. -Proof. - rewrite /nearest_round /ran_round. - destruct (upi_bound_O r). - case: (boolP (Order.lt _ _)); intros. - - rewrite Rabs_right. - + by apply Order.POrderTheory.ltW. - + apply Rle_ge. - left. - apply /RltP. - by rewrite /zero /= in H. - - rewrite Order.TotalTheory.ltNge in i. - apply negbNE in i. - rewrite Rabs_left_O; lra. -Qed. - - -Lemma nearest_round_bound_O' (r : R) : - let diff := (nearest_round r)%:~R - r in - Rabs diff <> 1/2 -> - (Rabs diff < 1/2)%O. -Proof. - generalize (nearest_round_bound_O r); intros. - rewrite Order.POrderTheory.lt_neqAle. - unfold diff in *. - move /eqP in H0. - by apply /andP. -Qed. - -Lemma Rabs_n (r : R) (n : nat) : - (Rabs r) *+n = Rabs (r *+ n). -Proof. - rewrite -mulr_natr -(mulr_natr r n) RnormM. - rewrite /mul /=. - f_equal. - rewrite Rabs_right //. - apply Rle_ge. - apply /RleP. - replace (IZR Z0) with (0%:R : R). - - rewrite ler_nat. - lia. - - by rewrite mulr0n /zero /=. - Qed. - -Lemma nearest_round_nat_mul_bound_R (r : R) (n : nat) : - (Rabs (add ((nearest_round r)%:~R *+n) (opp r*+n)) <= (1/2)*+n)%O. -Proof. - generalize (nearest_round_bound_O r); intros. - apply (ler_wMn2r (R := R_numDomainType) n) in H. - apply /RleP. - move /RleP in H. - eapply Rle_trans; [|apply H]. - right. - rewrite Rabs_n. - f_equal. - ring. -Qed. - -Lemma nearest_round_nat_mul_bound_R'_S (r : R) (n : nat) : - Rabs ((nearest_round r)%:~R - r) <> 1/2 -> - (Rabs (add ((nearest_round r)%:~R *+n.+1) (opp r*+n.+1)) < (1/2)*+n.+1)%O. -Proof. - intros. - generalize (nearest_round_bound_O' r H); intros. - apply (ltr_wMn2r (R := R_numDomainType) n.+1) in H0. - apply /RltP. - assert (0 < n.+1) by lia. - rewrite H1 in H0. - move /RltP in H0. - eapply Rle_lt_trans; [|apply H0]. - right. - rewrite Rabs_n. - f_equal. - ring. -Qed. - -Lemma upi_nat_mul_bound_R0_lt (r : R) (n : nat) : - ((0 : R) < ((upi r)%:~R *+n.+1 - r*+n.+1))%O. -Proof. - destruct (upi_bound_O r). - assert (0 < n.+1) by lia. - apply (ltr_wpMn2r H1 (R := R_numDomainType)) in H. - rewrite mul0rn in H. - by replace ((upi r)%:~R *+ n.+1 - r *+ n.+1 ) with - (((upi r)%:~R - r) *+ n.+1) by ring. -Qed. - -Lemma upi_nat_mul_bound_R0_lt_alt (r : R) (n : nat) : - (r *+ n.+1 < (upi r)%:~R *+n.+1)%O. -Proof. - generalize (upi_nat_mul_bound_R0_lt r n); intros. - lra. -Qed. - -Lemma upi_nat_mul_bound_R0_lt_alt_alt (r : R) (n : nat) : - (r *+ n.+1 < ((upi r)*+n.+1)%:~R)%O. -Proof. - generalize (upi_nat_mul_bound_R0_lt_alt r n); intros. - by replace ((upi r *+ n.+1)%:~R)%R with ((((upi r)%:~R) : R) *+ n.+1) by ring. -Qed. - -Lemma up_lt_le (r : R) (i : Z) : - (Rlt r (IZR i)) -> - Z.le (up r) i. -Proof. - intros. - destruct (archimed r); intros. - assert (Rle (IZR (up r) - 1) r) by coq_lra. - assert (IZR (up r) - 1 = IZR ((up r) - 1)). - { - rewrite /opp /add /one /=. - rewrite minus_IZR. - coq_lra. - } - assert (~(Z.lt i (up r))). - { - intros ?. - assert (Z.le i ((up r) - 1)) by lia. - assert (Rle (IZR i) (IZR ((up r) -1))). - { - by apply IZR_le. - } - rewrite -H3 in H6. - assert (Rle (IZR i) r). - { - eapply Rle_trans. - apply H6. - apply H2. - } - coq_lra. - } - lia. -Qed. - -Lemma upi_lt_le (r : R) (i : int) : - (r < i%:~R)%O -> - (upi r <= i)%O. -Proof. - intros. - generalize (up_lt_le r (Z_of_int i)); intros. - rewrite -IZREb in H0. - move /RltP in H. - specialize (H0 H). - rewrite /upi. - lia. -Qed. - -Lemma upi_nat_mul_bound_R0_le (r : R) (n : nat) : - ((0 : R) <= ((upi r)%:~R *+n - r*+n))%O. -Proof. - destruct n. - - lra. - - apply Order.POrderTheory.ltW. - apply upi_nat_mul_bound_R0_lt. -Qed. - -Lemma upi_nat_mul_le (r : R) (n : nat) : - (upi (r *+ n.+1) <= upi(r)*+n.+1)%O. -Proof. - apply upi_lt_le. - apply upi_nat_mul_bound_R0_lt_alt_alt. -Qed. - -Definition upiR (r: R) : R := (upi r) %:~R. - -Lemma eq_rel2 {A B:Type} {RR:A->B->bool} {a: A} {b:B} {c:A} {d:B} : - RR a b -> - a = c -> - b = d -> - RR c d. -Proof. - congruence. -Qed. - -Lemma upi_nat_mul' (r : R) (n : nat) : - (upi(r)*+n.+1 - upi (r *+ n.+1) < n.+1%:R)%O. -Proof. - destruct (upi_bound_O (r *+ n.+1)). - generalize (upi_nat_mul_bound_R r n.+1); intros. - assert ((upiR r) *+ n.+1 - (upiR (r*+n.+1)) < (n.+1%:R))%O. - { - replace ((upiR r) *+ n.+1 - upiR (r *+ n.+1)) with - ( (upiR r) *+ n.+1 - r*+n.+1 - (upiR (r *+ n.+1) - r*+n.+1)) by ring. - eapply Order.POrderTheory.lt_le_trans; [| apply H1]. - rewrite /upiR. - - have: (forall (a b : R), ((0 : R) < b)%O -> (a - b < a)%O) by (intros; lra). - by apply. - } - rewrite /upiR in H2. - rewrite -int_to_R_lt. - apply (eq_rel2 H2); ring. -Qed. - -Lemma upi_nat_mul'' (r : R) (n : nat) : - ((0:int) <= upi(r)*+n.+1 - upi (r *+ n.+1) < n.+1%:R)%O. -Proof. - apply /andP. - split. - - generalize (upi_nat_mul_le r n); intros. - lia. - - apply upi_nat_mul'. -Qed. - -Lemma upi_nat_mul (r : R) (n : nat) : - (upi(r)*+n - upi (r *+ n) < n%:R)%O. -Proof. - destruct n. - - by rewrite !mulr0n upi0 add0r. - - apply upi_nat_mul'. -Qed. - -Lemma upi_nat_mul_pos (r : R) (n : nat) : - ((0 : int) <= upi(r)*+n.+1 - upi (r *+ n.+1))%O. -Proof. - generalize (upi_nat_mul_le r n); intros. - lia. -Qed. - -Lemma upi_nat_mul_abs_S (r : R) (n : nat) : - `|upi (r *+ n.+1) - upi(r)*+n.+1| < n.+1. -Proof. - rewrite distnC. - assert (Posz `|upi r *+ n.+1 - upi (r *+ n.+1)| < Posz n.+1)%O. - { - rewrite gez0_abs; [| apply (upi_nat_mul_pos r n)]. - generalize (upi_nat_mul' r n); intros. - lia. - } - lia. -Qed. - -Lemma upi_nat_mul_abs (r : R) (n : nat) : - `|upi (r *+ n) - upi(r)*+n| <= n.+1. -Proof. - destruct n. - - by rewrite !mulr0n upi0 addr0. - - generalize (upi_nat_mul_abs_S r n). - lia. -Qed. - -Lemma Rabs_opp_sym (x y : R) : - Rabs (add x (opp y)) = Rabs (add y (opp x)). -Proof. - generalize (Rabs_minus_sym x y); intros. - rewrite /add /opp /=. - by rewrite /Rminus in H. -Qed. - -Lemma Rplus_add (x y : R) : - Rplus x y = add x y. -Proof. - by rewrite /add /=. -Qed. - -Lemma Ropp_opp(x : R) : - Ropp x = opp x. -Proof. - by rewrite /opp /=. -Qed. - -Lemma Rmult_mul (x y : R) : - Rmult x y = mul x y. -Proof. - by rewrite /mul /=. -Qed. - -Lemma Rabs_absz (j : int) : - Rabs (j %:~R) = (`|j|%:~R). -Proof. - rewrite /absz. - destruct j. - - rewrite Rabs_right //. - apply Rle_ge. - apply /RleP. - replace (IZR Z0) with ((0 %:~R):R). - + rewrite int_to_R_le; lia. - + by rewrite /zero /=. - - rewrite NegzE Rabs_left. - + rewrite Ropp_opp. - by rewrite mulrNz opprK. - + apply /RltP. - replace (IZR Z0) with ((0 %:~R):R). - * rewrite int_to_R_lt; lia. - * by rewrite /zero /=. -Qed. - -Lemma half_lt (j : int) (n : nat) : - (Posz (n/2)%N < j)%O -> - (Posz n < 2 * j)%O. -Proof. - intros. - generalize (Nat.div_mod_eq n 2); intros. - case: (boolP (n mod 2 == 0%N)); intros; try lia. - generalize (Nat.mod_upper_bound n 2); intros. - lia. -Qed. - -Lemma half_int_to_R_le (j : int) (n : nat): - ((j %:~R : R) <= (1/2 : R) *+ n)%O -> - (j <= (n / 2)%N)%O. -Proof. - intros. - case: (boolP (Order.le _ _)); intros; trivial. - rewrite -Order.TotalTheory.ltNge in i. - assert ((n : int) < 2*j)%O. - { - by apply half_lt. - } - rewrite -(@ltr_int R_numDomainType) in H0. - assert ( (n%:~R : R) / 2 < j%:~R)%O by lra. - rewrite -(@ltr_int R_numDomainType) in i. - assert ( (1 / 2 : R) *+ n < j%:~R)%O. - { - assert ((1 / 2 : R) *+ n = (n%:~R : R) / 2) by ring. - by rewrite H2. - } - lra. -Qed. - -Lemma half_int_to_R_le_alt (j : int) (n : nat): - ((j %:~R : R) <= (1/2 : R) *+ n)%O -> - (j <= (n ./2)%N)%O. -Proof. - intros. - case: (boolP (Order.le _ _)); intros; trivial. - assert ((n : int) < 2*j)%O by lia. - rewrite -(@ltr_int R_numDomainType) in H0. - assert ( (1 / 2 : R) *+ n < j%:~R)%O. - { - assert ((1 / 2 : R) *+ n = (n%:~R : R) / 2) by ring. - lra. - } - generalize (Order.POrderTheory.le_lt_trans H H1). - by rewrite Order.POrderTheory.ltxx. -Qed. - -Lemma half_int_to_R_lt (j : int) (n : nat): - ((j %:~R : R) < (1/2 : R) *+ n.+1)%O -> - (j <= (n./2)%N)%O. -Proof. - intros. - case: (boolP (odd n)); intros. - - assert (n.+1 = (2 * (n.+1 %/ 2))%N). - { - generalize (divn_eq n.+1 2); intros. - assert (n.+1 %% 2 = 0)%N. - { - rewrite modn2 oddS p //. - } - rewrite H1 addn0 in H0. - by rewrite {1}H0 mulnC. - } - rewrite H0 in H. - assert (1 / 2 *+ (2 * (n.+1 %/ 2)) = ((Posz (n.+1%/2))%:~R : R)) by (by field). - rewrite H1 ltr_int in H. - lia. - - apply Order.POrderTheory.ltW in H. - generalize (half_int_to_R_le_alt _ _ H); intros. - lia. -Qed. - -Lemma Rabs_absz_half (j : int) (n : nat) : - (Rabs (j %:~R) <= (1/2 : R) *+ n)%O -> - (`|j| <= (n/2)%N)%O. -Proof. - intros. - rewrite Rabs_absz in H. - by apply half_int_to_R_le in H. -Qed. - -Lemma Rabs_absz_half_le (j : int) (n : nat) : - (Rabs (j %:~R) <= (1/2 : R) *+ n)%O -> - (`|j| <= (n./2)%N)%O. -Proof. - intros. - rewrite Rabs_absz in H. - by apply half_int_to_R_le_alt in H. -Qed. - -Lemma Rabs_absz_half_lt (j : int) (n : nat) : - (Rabs (j %:~R) < (1/2 : R) *+ n.+1)%O -> - (`|j| <= (n./2)%N)%O. -Proof. - intros. - rewrite Rabs_absz in H. - by apply half_int_to_R_lt in H. -Qed. - -Lemma nearest_round_mul_abs_nat (r : R) (n : nat) : - `|nearest_round (r *+ n) - nearest_round(r)*+n| <= (n.+1)/2. -Proof. - generalize (nearest_round_nat_mul_bound_R r n); intros. - generalize (nearest_round_bound_O (r *+ n)); intros. - rewrite /= Rabs_opp_sym in H0. - generalize (lerD H H0); intros. - generalize (ler_normD - ((nearest_round r)%:~R *+ n + - r *+ n)%R - (r *+ n - (nearest_round (r *+ n))%:~R)%R ); intros. - generalize (Order.POrderTheory.le_trans H2 H1); intros. - rewrite -!addrA (addrA (-r*+n) _ _) in H3. - rewrite mulNrn addNr add0r in H3. - rewrite distnC. - rewrite -mulrSr -rmorphMn -rmorphN -rmorphD /= in H3. - by apply Rabs_absz_half in H3. -Qed. - -Lemma nearest_round_mul_abs_nat_alt (r : R) (n : nat) : - `|nearest_round (r *+ n) - nearest_round(r)*+n| <= (n.+1)./2. -Proof. - generalize (nearest_round_nat_mul_bound_R r n); intros. - generalize (nearest_round_bound_O (r *+ n)); intros. - rewrite /= Rabs_opp_sym in H0. - generalize (lerD H H0); intros. - generalize (ler_normD - ((nearest_round r)%:~R *+ n + - r *+ n)%R - (r *+ n - (nearest_round (r *+ n))%:~R)%R ); intros. - generalize (Order.POrderTheory.le_trans H2 H1); intros. - rewrite -!addrA (addrA (-r*+n) _ _) in H3. - rewrite mulNrn addNr add0r in H3. - rewrite distnC. - rewrite -mulrSr -rmorphMn -rmorphN -rmorphD /= in H3. - by apply Rabs_absz_half_le in H3. -Qed. - -Lemma nearest_round_0 : - nearest_round 0 = 0. -Proof. - rewrite /nearest_round /ran_round upi0 oppr0 addr0. - replace (((1%:~R : R )< 1 / 2)%O) with false by lra. - lia. -Qed. - -Lemma nearest_round_int_val (n : int) : - nearest_round (n %:~R) = n. -Proof. - rewrite /nearest_round /ran_round. - generalize (upi_intl n 0); intros. - rewrite Rplus_add addr0 upi0 in H. - rewrite H intrD addrC addrA addNr add0r. - replace ((1%:~R : R) < 1 / 2)%O with false by lra. - lia. -Qed. - -Lemma nearest_round_half (r : R) : - (upi r)%:~R - r = 1 / 2 <-> - Rabs ((nearest_round r)%:~R - r) = 1 / 2. -Proof. - rewrite /nearest_round /ran_round. - split; intros. - - rewrite H. - assert ((((1 / 2):R) < 1 / 2)%O = false) by lra. - rewrite H0 rmorphD /= /Rminus Rplus_add Ropp_opp. - rewrite -addrA addrC -addrA (addrC (- r) _) H. - replace ( ((-1)%:~R : R) + 1 / 2) with (- (1 / 2):R) by lra. - rewrite -Ropp_opp Rabs_Ropp Rabs_right // mul1r. - apply Rle_ge; left. - apply /RltP. - rewrite invr_gt0. - lra. - - case: (boolP ((upi r)%:~R - r < 1 / 2)%O); intros. - + rewrite p in H. - rewrite /Rminus Ropp_opp Rplus_add in H. - rewrite Rabs_right // in H. - left. - apply upi_bound. - + assert (((upi r)%:~R - r < 1 / 2)%O = false) by lra. - rewrite H0 in H. - rewrite rmorphD rmorphN /= in H. - destruct (upi_bound_O r). - rewrite /Rminus Ropp_opp Rplus_add in H. - case: (boolP (Order.le (0 : R) ((upi r)%:~R - 1%:~R - r)%R )); intros. - * move /RleP in p. - rewrite Rabs_right in H. - -- move /RleP in p. - lra. - -- by apply Rle_ge. - * rewrite -Order.TotalTheory.ltNge in i0. - rewrite Rabs_left in H. - -- rewrite Ropp_opp in H. - lra. - -- by apply /RltP. -Qed. - -Lemma nearest_round_not_half (r : R) : - (upi r)%:~R - r <> 1 / 2 -> - (Rabs ((nearest_round r)%:~R - r) < 1 / 2)%O. -Proof. - intros. - rewrite Order.POrderTheory.lt_def. - apply /andP. - split. - - apply /eqP. - intros ?. - symmetry in H0. - by rewrite -nearest_round_half in H0. - - apply nearest_round_bound_O. -Qed. - -Lemma nearest_round_not_half' (r : R) : - (Rabs ((nearest_round r)%:~R - r) <> 1 / 2) -> - (Rabs ((nearest_round r)%:~R - r) < 1 / 2)%O. -Proof. - intros. - rewrite Order.POrderTheory.lt_def. - apply /andP. - split. - - lra. - - apply nearest_round_bound_O. -Qed. - -Lemma nearest_round_half_val (r : R) : - Rabs ((nearest_round r)%:~R - r) = 1 / 2 -> - (nearest_round r)%:~R - r = -(1 / 2). -Proof. - intros. - rewrite -nearest_round_half in H. - rewrite /nearest_round /ran_round. - rewrite H. - case: (boolP (Order.lt _ _)); lra. -Qed. - -Lemma nearest_round_mul_abs_nat_half (r : R) (n : nat) : - Rabs ((nearest_round r)%:~R - r) = 1 / 2 -> - Rabs ((nearest_round (r *+ n.+1))%:~R - r *+ n.+1)%R = 1 / 2 -> - (`|nearest_round (r *+ n.+1) - nearest_round(r)*+n.+1| <= n./2)%O. -Proof. - intros. - rewrite /Rminus Ropp_opp Rplus_add in H. - apply nearest_round_half_val in H0. - apply nearest_round_half_val in H. - apply (f_equal (fun z => z *+ n.+1)) in H. - assert (((nearest_round (r *+ n.+1))%:~R:R) - (nearest_round r)%:~R *+ n.+1 = -(- (1 / 2) *+ n.+1) - (1 / 2)). - { - rewrite -H -H0. - ring. - } - replace ( - (1 / 2) *- n.+1 - 1 / 2) with (((1 / 2) *+ n):R) in H1 by ring. - rewrite -rmorphMn -rmorphN -rmorphD /= in H1. - apply (f_equal (fun z => Rabs z)) in H1. - rewrite (Rabs_right (1 / 2 *+ n)) in H1. - - move /eqP in H1. - rewrite Order.POrderTheory.eq_le in H1. - move /andP in H1. - destruct H1. - by apply Rabs_absz_half_le in H1. - - apply Rle_ge. - rewrite -mulrnAl. - apply /RleP. - apply mulr_ge0; [|lra]. - apply mulrn_wge0; lra. - Qed. - -Lemma nearest_round_mul_abs_nat_not_half1 (r : R) (n : nat) : - Rabs ((nearest_round r)%:~R - r) <> 1 / 2 -> - `|nearest_round (r *+ n.+1) - nearest_round(r)*+n.+1| <= n.+1./2. -Proof. - intros. - generalize (nearest_round_nat_mul_bound_R'_S r n H); intros. - generalize (nearest_round_bound_O (r *+ n.+1)); intros. - simpl in H1. - rewrite Rabs_opp_sym in H1. - generalize (ltr_leD H0 H1); intros. - generalize (ler_normD - ((nearest_round r)%:~R *+ n.+1 + - r *+ n.+1)%R - (r *+ n.+1 - (nearest_round (r *+ n.+1))%:~R)%R ); intros. - generalize (Order.POrderTheory.le_lt_trans H3 H2); intros. - rewrite -addrA (addrA _ (r *+ n.+1) _) in H4. - replace (- r *+ n.+1 + r *+ n.+1) with (0:R) in H4 by ring. - rewrite add0r -mulrSr in H4. - rewrite distnC. - rewrite -rmorphMn -rmorphN -rmorphD /= in H4. - by apply Rabs_absz_half_lt in H4. -Qed. - -Lemma nearest_round_mul_abs_nat_not_half2 (r : R) (n : nat) : - Rabs ((nearest_round (r *+ n.+1))%:~R - r *+ n.+1)%R <> 1 / 2 -> - `|nearest_round (r *+ n.+1) - nearest_round(r)*+n.+1| <= n.+1./2. -Proof. - intros. - generalize (nearest_round_nat_mul_bound_R r n.+1); intros. - generalize (nearest_round_bound_O' (r *+ n.+1) H); intros. - rewrite mulNrn Rabs_opp_sym in H0. - generalize (ltr_leD H1 H0); intros. - generalize (ler_normD - ((nearest_round (r *+ n.+1))%:~R - r *+ n.+1)%R - (r *+ n.+1 - (nearest_round r)%:~R *+ n.+1)%R ); intros. - generalize (Order.POrderTheory.le_lt_trans H3 H2); intros. - rewrite -addrA (addrA _ (r *+ n.+1) _) in H4. - replace (r *- n.+1 + r *+ n.+1) with (0:R) in H4 by ring. - rewrite add0r (addrC (1 / 2) _) -mulrSr in H4. - rewrite -rmorphMn -rmorphN -rmorphD /= in H4. - by apply Rabs_absz_half_lt in H4. -Qed. - -Lemma nearest_round_mul_abs_nat_0 (r : R) : - `|nearest_round (r *+ 0) - nearest_round(r)*+0| = 0%N. -Proof. - by rewrite mulr0n nearest_round_0 mulr0n oppr0 addr0 /=. -Qed. - -Lemma nearest_round_mul_abs_nat' (r : R) (n : nat) : - `|nearest_round (r *+ n) - nearest_round(r)*+n| <= n./2. -Proof. - destruct n. - - rewrite nearest_round_mul_abs_nat_0. - lia. - - case: (boolP (Rabs ((nearest_round r)%:~R - r) == 1 / 2)). - + case: (boolP (Rabs ((nearest_round (r *+ n.+1))%:~R - r *+ n.+1)%R == 1 / 2)). - * intros. - move /eqP in p. - move /eqP in p0. - generalize (nearest_round_mul_abs_nat_half r n p0 p). - lia. - * intros. - move /eqP in i. - by apply nearest_round_mul_abs_nat_not_half2. - + intros. - move /eqP in i. - by apply nearest_round_mul_abs_nat_not_half1. -Qed. - -Lemma nearest_round_mul_abs_nat_opp (r : R) (n : nat) : - `|nearest_round r *+ n + nearest_round (r *- n)| <= - n.+1 / 2. -Proof. - replace (r *-n) with ((opp r) *+ n) by ring. - generalize (nearest_round_nat_mul_bound_R r n); intros. - generalize (nearest_round_bound_O ((opp r) *+ n)); intros. - simpl in H0. - replace (opp (natmul (opp r) n)) with (natmul r n) in H0 by ring. - generalize (lerD H H0); intros. - generalize (ler_normD - ((nearest_round r)%:~R *+ n + - r *+ n)%R - ((nearest_round (- r *+ n))%:~R + r *+ n)%R); intros. - generalize (Order.POrderTheory.le_trans H2 H1); intros. - rewrite -addrA (addrC (-r *+ n) _) -addrA in H3. - replace (r *+ n + (-r) *+n) with (0 : R) in H3 by ring. - rewrite addr0 -mulrSr in H3. - rewrite -rmorphMn -rmorphD /= in H3. - by apply Rabs_absz_half in H3. -Qed. - -Lemma nearest_round_mul_abs_nat_opp_not_half1 (r : R) (n : nat) : - Rabs ((nearest_round r)%:~R - r) <> 1 / 2 -> - `|nearest_round r *+ n.+1 + nearest_round (r *- n.+1)| <= n.+1 ./2. -Proof. - intros. - generalize (nearest_round_nat_mul_bound_R'_S r n H); intros. - generalize (nearest_round_bound_O ((opp r) *+ n.+1)); intros. - rewrite mulNrn in H0. - rewrite mulNrn opprK /= in H1. - generalize (ltr_leD H0 H1); intros. - generalize (ler_normD - ((nearest_round r)%:~R *+ n.+1 - r *+ n.+1)%R - ((nearest_round (r *- n.+1))%:~R + r *+ n.+1)%R); intros. - generalize (Order.POrderTheory.le_lt_trans H3 H2); intros. - rewrite -addrA (addrC _(r *+ n.+1)) (addrA (r*-n.+1) _ _) addNr add0r in H4. - rewrite -mulrSr -rmorphMn -rmorphD /= in H4. - by apply Rabs_absz_half_lt in H4. -Qed. - -Lemma nearest_round_mul_abs_nat_opp_not_half2 (r : R) (n : nat) : - Rabs ((nearest_round (- r *+ n.+1))%:~R - - r *+ n.+1)%R <> 1 / 2 -> - - `|nearest_round r *+ n.+1 + nearest_round (r *- n.+1)| <= n.+1 ./2. -Proof. - intros. - generalize (nearest_round_nat_mul_bound_R r n.+1); intros. - generalize (nearest_round_bound_O' ((opp r) *+ n.+1) H); intros. - rewrite mulNrn in H0. - rewrite mulNrn opprK in H1. - generalize (ltr_leD H1 H0); intros. - generalize (ler_normD - ((nearest_round (r *- n.+1))%:~R + r *+ n.+1)%R - ((nearest_round r)%:~R *+ n.+1 - r *+ n.+1)%R); intros. - generalize (Order.POrderTheory.le_lt_trans H3 H2); intros. - rewrite -addrA in H4. - rewrite (addrC ((nearest_round r)%:~R *+ n.+1) _) in H4. - rewrite (addrA (r*+n.+1) _ _) addrN add0r in H4. - rewrite -rmorphMn -rmorphD -mulrS in H4. - apply Rabs_absz_half_lt in H4. - lia. -Qed. - -(* - r = IZR (Int_part r) + frac_part r. - 0 <= frac_part r < 1. -*) - -Lemma up_opp_Z (r : R) : - IZR (up r) = Rplus r 1 -> - up (-r) = Zplus (- (up r)) (Zpos 2). -Proof. - intros. - symmetry. - apply tech_up. - - rewrite plus_IZR opp_IZR H. - coq_lra. - - rewrite plus_IZR opp_IZR H. - coq_lra. -Qed. - -Lemma up_opp_nZ (r : R) : - IZR (up r) <> Rplus r 1 -> - up (-r) = Zplus (- (up r)) (Zpos 1). -Proof. - intros. - symmetry. - destruct (archimed r). - apply tech_up. - - rewrite plus_IZR opp_IZR. - coq_lra. - - rewrite plus_IZR opp_IZR. - coq_lra. -Qed. - -Lemma upi_opp_int (r : R) : - (upi r)%:~R = r + 1 -> - upi (opp r) = - (upi r) + 2. -Proof. - intros. - rewrite /upi up_opp_Z. - - lia. - - by rewrite -IZRE /add /one /= in H. -Qed. - -Lemma upi_opp_nint (r : R) : - (upi r)%:~R <> r + 1 -> - upi (opp r) = - (upi r) + 1. -Proof. - intros. - rewrite /upi up_opp_nZ. - - lia. - - by rewrite -IZRE /add /one /= in H. -Qed. - -Lemma nearest_round_opp_not_half (r : R) : - (upi r)%:~R - r <> 1/2 -> - nearest_round (opp r) = - nearest_round r. -Proof. - rewrite /nearest_round /ran_round. - case: (boolP ((upi r)%:~R == r + 1)); intros. - - move /eqP in p. - rewrite (upi_opp_int r p) opprK rmorphD rmorphN /= p. - case: (boolP (Order.lt _ _)); intros. - + by assert false by lra. - + case: (boolP (Order.lt _ _)); intros. - * by assert false by lra. - * ring. - - move /eqP in i. - rewrite (upi_opp_nint r i) opprK rmorphD rmorphN /=. - case: (boolP (Order.lt _ _)); intros. - + case: (boolP (Order.lt _ _)); intros. - * by assert false by lra. - * ring. - + case: (boolP (Order.lt _ _)); intros. - * ring. - * by assert false by lra. -Qed. - -Lemma nearest_round_opp_half (r : R) : - (upi r)%:~R - r = 1/2 -> - nearest_round (-r) = - nearest_round r - 1. -Proof. - rewrite /nearest_round /ran_round. - intros. - assert( (upi r)%:~R <> r + 1 ) by lra. - rewrite (upi_opp_nint r H0) rmorphD rmorphN /= opprK. - case : (boolP (Order.lt _ _)); intros. - - by assert false by lra. - - case : (boolP (Order.lt _ _)); intros. - + by assert false by lra. - + ring. -Qed. - -Lemma upi_mul_abs (r : R) (n : int) : - `|upi (r *~ n) - (upi r) *~n| <= `|n|+1. -Proof. - destruct n; simpl. - - rewrite -!pmulrn. - generalize (upi_nat_mul_abs r n); intros. - lia. - - rewrite distnC /intmul /=. - generalize (upi_nat_mul_abs_S r n); intros. - rewrite distnC in H. - case: (boolP ((upi (r *+ n.+1))%:~R == (r*+n.+1)+1)); intros. - + rewrite (upi_opp_int _ (eqP p)) opprD opprK addrC distnC. - rewrite opprB (addrC 2 _) addrA. - generalize (absz_triang (upi r *+ n.+1 - upi (r *+ n.+1)) 2); intros. - lia. - + move /eqP in i. - rewrite (upi_opp_nint _ i) opprD opprK addrC distnC. - rewrite opprB (addrC 1 _) addrA. - generalize (absz_triang (upi r *+ n.+1 - upi (r *+ n.+1)) 1); intros. - lia. -Qed. - -Lemma upi_mul_abs_alt (r : R) (n : int) : - `|(upi (mul r n%:~R) - (upi r) * n)%Z| <= `|n|+1. -Proof. - generalize (upi_mul_abs r n); intros. - replace (upi (mul r n%:~R) - upi r * n)%R with - (upi (r *~ n) - upi r *~ n); trivial. - f_equal. - - f_equal; lra. - - f_equal; f_equal; lia. -Qed. - -Lemma nearest_round_mul_abs (r : R) (n : int) : - `|nearest_round (r *~ n) - nearest_round(r)*~n| <= (`|n|+1)/2. -Proof. - destruct n. - - rewrite -!pmulrn. - generalize (nearest_round_mul_abs_nat r n); intros. - replace (`|n| + 1)%N with (n.+1) by lia. - lia. - - rewrite distnC /intmul NegzE. - replace (`|-(Posz n.+1)|+1)%N with (n.+2) by lia. - rewrite -opprD abszN. - apply nearest_round_mul_abs_nat_opp. -Qed. - -Lemma nearest_round_int_mul (n1 n2 d : int) : - d <> 0 -> - `|(nearest_round_int (n1 * n2) d - nearest_round_int n1 d * n2)%R| <= (`|n2| + 1)/2. -Proof. - rewrite /nearest_round_int intrM. - rewrite (_: (n1%:~R * n2%:~R / d%:~R) = ((n1%:~R / d%:~R)%R * n2%:~R )); last by lra. - move: {n1} (((n1%:~R / d%:~R)%R)) => n1 _ {d}. - replace (n1 * n2%:~R) with (n1*~n2) by ring. - replace (nearest_round n1 * n2) with (nearest_round n1 *~ n2). - apply nearest_round_mul_abs. - destruct n2; ring. -Qed. - -Lemma nearest_round_int_add2_ex (n1 n2 d : int) : - d <> 0 -> - let sum := nearest_round_int n1 d + nearest_round_int n2 d in - { n3 : int | - nearest_round_int (n1 + n2) d = sum + n3 /\ - `|n3| <= 1}. -Proof. - intros. - exists (nearest_round_int (n1 + n2) d - sum). - split; try lia. - by apply nearest_round_int_add2. -Qed. - -Lemma nearest_round_int_mul_ex (n1 n2 d : int) : - d <> 0 -> - { n3 : int | - nearest_round_int (n1 * n2) d = nearest_round_int n1 d * n2 + n3 /\ - `|n3| <= (`|n2|+1)/2}. -Proof. - intros. - exists (nearest_round_int (n1 * n2) d - nearest_round_int n1 d * n2). - split; try lia. - by apply nearest_round_int_mul. -Qed. - -Definition div_round (a : {poly int}) (d : int) : {poly int} := - map_poly (fun c => nearest_round_int c d) a. - -Lemma div_round0 (den : int) : - div_round (0 : {poly int}) den = 0. -Proof. - apply map_poly0. -Qed. - -Lemma nth_map_default: - forall [T1 : Type] (x1 : T1) [T2 : Type] (x2 : T2) (f : T1 -> T2) [n : nat] [s : seq T1], - f x1 = x2 -> - nth x2 [seq f i | i <- s] n = f (nth x1 s n). -Proof. - intros. - case/orP: (leqVgt (size s) n) => ineq. - - by rewrite !nth_default // size_map. - - by rewrite (nth_map x1). -Qed. - -Lemma div_round_mul_add (a b : {poly int}) (d : int) : - d <> 0 -> - div_round (a + d *: b) d = div_round a d + b. -Proof. - intros. - rewrite /div_round !map_polyE -polyP => i. - rewrite coefD !coef_Poly. - rewrite !(nth_map_default 0 0); try by rewrite nearest_round_int0. - by rewrite coefD -nearest_round_int_mul_add_r // coefZ. -Qed. - -Lemma div_round_muln_add (a b : {poly int}) (d : nat) : - d <> 0%N -> - div_round (a + b *+ d) d = div_round a d + b. -Proof. - intros. - rewrite -div_round_mul_add; try lia. - by rewrite -scaler_nat /GRing.scale /= !scale_polyE natz. -Qed. - -Lemma div_round_muln (b : {poly int}) (d : nat) : - d <> 0%N -> - div_round (b *+ d) d = b. -Proof. - intros. - generalize (div_round_muln_add 0 b d H); intros. - by rewrite add0r div_round0 add0r in H0. -Qed. - -Lemma div_round_muln_add_l (a b : {poly int}) (d : nat) : - d <> 0%N -> - div_round (b *+ d + a) d = b + div_round a d. -Proof. - intros. - rewrite addrC div_round_muln_add //. - by rewrite addrC. -Qed. - -Lemma div_round_add2_ex (a b : {poly int}) (d : int) : - d <> 0 -> - { c : {poly int} | - div_round (a + b) d = div_round a d + div_round b d + c /\ - icoef_maxnorm c <= 1}. -Proof. - exists (div_round (a + b) d - (div_round a d + div_round b d)). - split. - - ring. - - rewrite /icoef_maxnorm /div_round. - apply /bigmax_leqP => i _. - rewrite !(coefD,coefN). - rewrite !coef_map_id0; try by rewrite nearest_round_int0. - rewrite !coefD. - by apply nearest_round_int_add2. -Qed. - -Definition div_round_add2_perturb (a b : {poly int}) (d : int) (dn0 : d <> 0) : {poly int} - := sval (div_round_add2_ex a b d dn0). - -Lemma div_round_add2_eq (a b : {poly int}) (d : int) (dn0 : d <> 0) - : div_round (a + b) d = div_round a d + div_round b d + - div_round_add2_perturb a b d dn0. -Proof. - rewrite /div_round_add2_perturb. - case: (div_round_add2_perturb a b d dn0). - case: div_round_add2_ex; intros; simpl. - tauto. -Qed. - -Lemma div_round_add2_eq_alt (a b : {poly int}) (d : int) (dn0 : d <> 0) : - div_round a d + div_round b d = div_round (a + b) d - div_round_add2_perturb a b d dn0. -Proof. - generalize (div_round_add2_eq a b d dn0); intros. - rewrite H. - ring. -Qed. - -Lemma div_round_add2_perturb_small (a b : {poly int}) (d : int) (dn0 : d <> 0) - : icoef_maxnorm (div_round_add2_perturb a b d dn0) <= 1. -Proof. - rewrite /div_round_add2_perturb. - case: (div_round_add2_perturb a b d dn0). - case: div_round_add2_ex; intros; simpl. - tauto. -Qed. - -Definition q_reduce (q : nat) (p : {poly int}) : {poly 'Z_q} := - map_poly (fun c => c%:~R) p. - -Lemma q_reduce_is_rmorphism (q : nat) : - rmorphism (q_reduce q). -Proof. - apply map_poly_is_rmorphism. -Qed. - -Canonical q_reduce_rmorphism (q : nat) := RMorphism (q_reduce_is_rmorphism q). - -Definition public_key {q : nat} (e s : {poly int}) (a : {poly 'Z_q}) : {poly {poly 'Z_q}} := - Poly [:: (- a * (q_reduce q s) + (q_reduce q e)); a]. - -Definition FHE_encrypt {q : nat} (p : {poly 'Z_q}) (v e0 e1 : {poly int}) (evkey : {poly {poly 'Z_q}}) := - Poly [:: (p + q_reduce q e0); q_reduce q e1] + (q_reduce q v) *: evkey. - -Definition FHE_decrypt {q : nat} (s : {poly int}) (pp : {poly {poly 'Z_q}}) := - pp.[q_reduce q s]. - -Lemma decrypt_encrypt {q : nat} (e s v e0 e1 : {poly int}) (a p : {poly 'Z_q}) : - FHE_decrypt s (FHE_encrypt p v e0 e1 (public_key e s a)) = - p + (q_reduce q e0) + q_reduce q e1 * q_reduce q s + q_reduce q v * q_reduce q e. -Proof. - rewrite /FHE_decrypt /FHE_encrypt /public_key. - rewrite hornerD hornerZ !horner_Poly /= mul0r !add0r. - ring. -Qed. - -Lemma decrypt_add {q : nat} (P Q : {poly 'Z_q}) (PP QQ : {poly {poly 'Z_q}}) (s : {poly int}) : - FHE_decrypt s PP = P -> - FHE_decrypt s QQ = Q -> - FHE_decrypt s (FHE_add PP QQ) = P + Q. -Proof. - rewrite /FHE_decrypt /FHE_add. - intros. - by rewrite hornerD H H0. -Qed. - -Lemma decrypt_mult_base {q : nat} (P Q : {poly 'Z_q}) (PP QQ : {poly {poly 'Z_q}}) (s : {poly int}) : - FHE_decrypt s PP = P -> - FHE_decrypt s QQ = Q -> - FHE_decrypt s (FHE_mult_base PP QQ) = P * Q. -Proof. - rewrite /FHE_decrypt /FHE_mult_base. - intros. - by rewrite hornerM H H0. -Qed. - -Definition key_switch_key {q p : nat} (s s2 e : {poly int}) (a : {poly 'Z_(p*q)}) : {poly {poly 'Z_(p*q)}} := - Poly [:: (-a * (q_reduce (p * q) s) + (q_reduce (p * q) e) + q_reduce (p * q) (s2 *+ p)); a]. - -Definition ev_key {q p : nat} (s e : {poly int}) (a : {poly 'Z_(p*q)}) := - key_switch_key s (exp s 2) e a. - -Definition linearize {q p : nat} (c0 c1 c2 : {poly 'Z_q}) - (evkey : {poly {poly 'Z_(p*q)}}) := - Poly [:: c0; c1] + - map_poly (fun P => q_reduce q (div_round ((zlift c2) * (zlift P)) (p%:Z))) - evkey. - -(* evkey = [:: (-a * (q_reduce (p * q) s) + (q_reduce (p * q) e) + (q_reduce (p * q) (s^2 *+ p))); a] - *) -Ltac notHyp P := - match goal with - | [ _ : P |- _ ] => fail 1 - | _ => idtac - end. - -Ltac extend_as_perturb P := - let t := type of P in - notHyp t; generalize P; - let x := (fresh "perturb_small") in - intros x. - -Ltac perturb_zlift_facts - := repeat match goal with - | [|- context [zlift_add2_perturb ?a ?b ?c]] => - extend_as_perturb (zlift_add2_perturb_small a b c) - end. - -Ltac perturb_div_round_facts - := repeat match goal with - | [|- context [div_round_add2_perturb ?a ?b ?c ?d]] => - extend_as_perturb (div_round_add2_perturb_small a b c d) - end. - -Lemma reduce_prod1 (p q : nat) (a : {poly int}) : - q_reduce (q) (a *+ p) = (q_reduce q a) *+ p. -Proof. - by rewrite rmorphMn. -Qed. - -Lemma modn_mul2 (p q r: nat) : - p %% q * r = (p * r) %[mod q]. -Proof. - by rewrite modnMml. -Qed. - -Lemma modn_mul2r (p q r: nat) : - r * (p %% q) = r * p %[mod q]. -Proof. - by rewrite modnMmr. -Qed. - -Lemma modn_prod2 (p q n : nat) : - ((n %% (p * q) * p) %% (p * q))%N = (n %% q * p)%N. -Proof. - rewrite modn_mul2 (mulnC p _). - by rewrite muln_modl. -Qed. - -Lemma modn_prod2r (p q n : nat) : - ((p * (n %% (p * q))) %% (p * q))%N = (p * (n %% q))%N. -Proof. - by rewrite modn_mul2r muln_modr. -Qed. - -Lemma reduce_prod2 (p q : nat) (a : int) : - 1 < q -> - 1 < p*q -> - nat_of_ord ((a %:~R : 'Z_(p*q))*+ p) = ((a%:~R : 'Z_q) * p)%N. -Proof. - intros. - destruct a; rewrite /intmul !Zp_mulrn mul1n /inZp /= !Zp_cast //. - by apply modn_prod2. - - rewrite modn_prod2 //. - f_equal. - rewrite !modnB; [|lia|lia|lia|lia]. - rewrite modnMl modnn !addn0. - replace (modn (modn (S n) (muln p q)) q) with - (modn (modn (S n) q) q); trivial. - rewrite modn_mod modn_dvdm //. - apply dvdn_mull. - by apply dvdnn. -Qed. - -Lemma div2_le (p q : nat) : - (p <= q/2)%N = (2 * p <= q)%N. -Proof. - move: (Nat.divmod_spec q 1 0 1 (le_refl _)) => /=. - case: (Nat.divmod q 1 0 1) => /= x u. - rewrite !Nat.add_0_r. - move=> [-> ubound]. - destruct u; lia. -Qed. - -Lemma liftc_reduce_prod2_nat (p q : nat) (a : nat) : - 1 < q -> - 1 < p*q -> - zliftc ((a %:~R : 'Z_(p*q))*+ p) = (zliftc (a%:~R : 'Z_q)) *+p. -Proof. - move=> qbig pqbig. - rewrite /zliftc. - have ->: ((a %:~R : 'Z_(p*q))*+ p <= p * q / 2) = (((a%:~R : 'Z_q)) <= q /2). - { - rewrite -mulr_natl !div2_le. - rewrite /intmul !Zp_mulrn mul1n /inZp /= !Zp_cast //. - rewrite [modn (S 0) (muln p q)]modn_small //. - rewrite [modn (S 0) q]modn_small // !mul1n. - rewrite modn_mul2 modn_prod2r. - rewrite mulnA (mulnC 2%N p) -mulnA. - rewrite leq_pmul2l //. - lia. - } - case: leqP; rewrite reduce_prod2 //; lia. -Qed. - -Lemma Nat_even_even (x : nat) : Nat.even x = ~~ odd x. -Proof. - induction x => //=. - destruct x; [lia |]. - by rewrite -IHx Nat.even_succ Nat.negb_odd. -Qed. - -Lemma Nat_odd_odd (x : nat) : Nat.odd x = odd x. -Proof. - by rewrite -Nat.negb_even Nat_even_even negbK. -Qed. - -Lemma qodd_half (q : nat) : - odd q -> - (q = q/2 + q/2 + 1)%N. -Proof. - rewrite -!Nat.div2_div => qodd. - rewrite [LHS]Nat.div2_odd Nat_odd_odd qodd /=. - lia. -Qed. - -Lemma liftc_neg0 (q : nat) (a : int) : - 1 < q -> - (a %:~R : 'Z_q) = 0 -> - zliftc ((-a)%:~R : 'Z_q) = - zliftc (a %:~R : 'Z_q). -Proof. - intros. - rewrite /zliftc. - assert (((-a) %:~R : 'Z_q) = 0). - { - apply (f_equal (fun z => opp z)) in H0. - rewrite -rmorphN /= in H0. - by rewrite H0 oppr0. - } - rewrite H0 H1 /=; lia. -Qed. - -Lemma liftc_neg_prop (q : nat) (a : 'Z_q) : - (q = q/2 + q/2 + 1)%N -> - q - a <= q/2 = ~~ (a <= q/2). -Proof. - lia. -Qed. - -Lemma liftc_neg_prop_alt (q : nat) (a : 'Z_q) : - 1 < q -> - (q - a != a)%N -> - q - a <= q/2 = ~~ (a <= q/2). -Proof. - intros. - generalize (Nat.div_mod_eq q 2); intros. - assert (2 * (q/2) <= q)%coq_nat by lia. - destruct a. - simpl. - simpl in H0. - rewrite Zp_cast // in i. - case: (boolP (q - m <= q/2)); intros. - - case (boolP (m <= q/2)); intros i0; rewrite i0; lia. - - case (boolP (m <= q/2)); intros i1; rewrite i1; try tauto. - assert ((q mod 2)%coq_nat < 2). - { - rewrite modulo_modn. - by rewrite ltn_mod. - } - lia. -Qed. - -Lemma Z_q_small {q:nat} (c : 'Z_q) : - 1 < q -> - c < q. -Proof. - move=> qbig. - case: c => m i /=. - by rewrite Zp_cast // in i. -Qed. - -Lemma liftc_neg (q : nat) (a : int) : - 1 < q -> - odd q -> - (* ~~ intdiv.dvdz q (2 * a) *) - zliftc ((-a)%:~R : 'Z_q) = - zliftc (a %:~R : 'Z_q). -Proof. - move=> qbig qodd. - rewrite /zliftc. - move: (qodd_half q qodd) => qsum. - move : (Z_q_small (a%:~R : 'Z_q) qbig) => asmall. - case: (eqVneq (a %:~R : 'Z_q) 0)=>i; [by apply liftc_neg0|]. - assert (apos: 0 < (a%:~R : 'Z_q)). - { - have anneg: 0 <= (a%:~R : 'Z_q) by lia. - by rewrite Order.NatOrder.ltn_def i anneg. - } - - have ->: (((- a)%:~R : 'Z_q) <= q / 2) = (~~ ((a%:~R : 'Z_q) <= q / 2)). - { - rewrite rmorphN /= {1 3}Zp_cast //. - rewrite modnB; [|lia|lia]. - assert (0 < (a%:~R : 'Z_q) %% q). - { - rewrite modn_small //. - } - rewrite modnn H mul1n addn0 modn_small //. - by apply liftc_neg_prop. - } - rewrite rmorphN /=. - case: leqP => /=; intros; rewrite {1 3}Zp_cast // modn_small //; lia. -Qed. - -Lemma liftc_neg_alt (q : nat) (a : int) : - 1 < q -> - (q - (a%:~R : 'Z_q) != (a%:~R : 'Z_q))%N -> - zliftc ((-a)%:~R : 'Z_q) = - zliftc (a %:~R : 'Z_q). -Proof. - move=> qbig not_ahalf. - rewrite /zliftc. - move : (Z_q_small (a%:~R : 'Z_q) qbig) => asmall. - case: (eqVneq (a %:~R : 'Z_q) 0)=>i; [by apply liftc_neg0|]. - assert (apos: 0 < (a%:~R : 'Z_q)). - { - have anneg: 0 <= (a%:~R : 'Z_q) by lia. - by rewrite Order.NatOrder.ltn_def i anneg. - } - - have ->: (((- a)%:~R : 'Z_q) <= q / 2) = (~~ ((a%:~R : 'Z_q) <= q / 2)). - { - rewrite rmorphN /= {1 3}Zp_cast //. - rewrite modnB; [|lia|lia]. - assert (0 < (a%:~R : 'Z_q) %% q). - { - rewrite modn_small //. - } - rewrite modnn H mul1n addn0 modn_small //. - by apply liftc_neg_prop_alt. - } - rewrite rmorphN /=. - case: leqP => /=; intros; rewrite {1 3}Zp_cast // modn_small //; lia. -Qed. - -Lemma int_Zp_0 {q : nat} (a : int) : - 1 < q -> - (a%:~R : 'Z_q) = 0 -> - intdiv.modz a q = 0. -Proof. - move => qbig. - rewrite Zp_int //. - move/(f_equal val). - case: a => n ; rewrite /inZp /= Zp_cast //. - - rewrite intdiv.modz_nat => ->; lia. - - rewrite intdiv.modNz_nat; try lia. - rewrite modnS. - case: (boolP (q %| n.+1)%N) => pred a0. - + suff: ((n %% q)%N = (q-1)%N) by lia. - move /dvdnP in pred. - case: pred => x xeq. - have HH: (n + q = x * q + (q-1))%N by lia. - have: ((n + q) %% q = (x * q + (q - 1)) %% q)%N. - { - by rewrite HH. - } - rewrite -modnDm -[RHS]modnDm modnn addn0 modnMl add0n modn_mod => ->. - rewrite modn_mod modn_small; lia. - + rewrite modn_small in a0; lia. - Qed. - -Lemma int_Zp_0_alt {q : nat} (a : int) : - 1 < q -> - intdiv.modz a q = 0 -> - (a%:~R : 'Z_q) = 0. -Proof. - move => qbig qmod. - rewrite Zp_int //. - destruct a. - - rewrite intdiv.modz_nat in qmod. - apply ord_inj. - rewrite /= Zp_cast //. - lia. - - rewrite intdiv.modNz_nat in qmod; try lia. - apply ord_inj. - rewrite /= Zp_cast //. - rewrite modnS. - case: (boolP (q %| n.+1)%N) => pred. - + by rewrite subn0 modnn. - + assert ((n%% q)%N = q-1)%N by lia. - rewrite H. - replace ((q-1).+1) with q by lia. - by rewrite subnn mod0n. -Qed. - -Lemma int_Zp_0_iff (q : nat) (a : int) : - 1 < q -> - intdiv.modz a q = 0 <-> - (a%:~R : 'Z_q) = 0. -Proof. - move => qbig. - split. - - by apply int_Zp_0_alt. - - by apply int_Zp_0. -Qed. - -Lemma int_Zp_eq_iff (q : nat) (a1 a2 : int) : - 1 < q -> - intdiv.modz a1 q = intdiv.modz a2 q <-> - (a1%:~R : 'Z_q) = (a2%:~R : 'Z_q). -Proof. - split; intros. - - assert (intdiv.modz (a1 - a2) q = 0). - { - by rewrite -intdiv.modzDm H0 intdiv.modzDm subrr intdiv.mod0z. - } - rewrite int_Zp_0_iff // in H1. - rewrite rmorphD rmorphN /= in H1. - apply (f_equal (fun z => z + a2%:~R)) in H1. - rewrite add0r in H1. - rewrite -H1. - ring. - - assert (((a1 - a2)%:~R : 'Z_q) = 0). - { - by rewrite rmorphD rmorphN /= H0 subrr. - } - rewrite -int_Zp_0_iff // in H1. - generalize (intdiv.divz_eq (a1 - a2) q); intros. - rewrite H1 addr0 in H2. - apply (f_equal (fun z => z + a2)) in H2. - replace (a1 - a2 + a2) with (a1) in H2 by lia. - rewrite H2 -intdiv.modzDm intdiv.modzMl add0r. - by rewrite intdiv.modz_mod. -Qed. - -Lemma liftc_neg_alt_alt (q : nat) (a : int) : - 1 < q -> - intdiv.dvdz q (2 * a) \/ - zliftc ((-a)%:~R : 'Z_q) = - zliftc (a %:~R : 'Z_q). -Proof. - move=> qbig. - case: (boolP (intdiv.dvdz q (2 * a))) => div; [by left|]. - right. - apply liftc_neg_alt => //. - move: div. - apply contraNN => eqq. - apply/intdiv.dvdzP. - exists (intdiv.divz (2%Z * a) q). - have: 2 * (a%:~R : 'Z_q) == 0. - { - move: eqq. - move/eqP/(f_equal (fun x:nat => x + (a%:~R : 'Z_q)))%nat. - have aqsmall: (a%:~R : 'Z_q) < q by apply Z_q_small. - rewrite (_:(q - (a%:~R : 'Z_q) + (a%:~R : 'Z_q))%N = q); [| lia]. - move => eqq. - suff: (a%:~R + a%:~R) == ( 0 : 'Z_q). - { - move/eqP => <-. - apply/eqP. - ring. - } - by rewrite /eq_op/= -eqq Zp_cast // modnn. - } - generalize (intdiv.divz_eq (2%Z*a) q); intros eqq2 a0. - rewrite {1}eqq2. - suff ->:(intdiv.modz (2%Z*a) q = 0). - - by rewrite addr0. - - rewrite -(intrM (Zp_ringType (Zp_trunc q)) (2%Z) a) in a0. - move /eqP in a0. - by apply int_Zp_0. -Qed. - -Lemma liftc_reduce_prod2 (p q : nat) (a : int) : - 1 < q -> - 1 < p -> - zliftc ((a %:~R : 'Z_(p*q))*+ p) = (zliftc (a%:~R : 'Z_q)) *+p. -Proof. - move=> qbig pbig. - assert (pqbig : 1 < p*q) by lia. - rewrite /zliftc. - have ->: ((a %:~R : 'Z_(p*q))*+ p <= p * q / 2) = (((a%:~R : 'Z_q)) <= q /2). - { - rewrite -mulr_natl !div2_le. - destruct a; rewrite /intmul !Zp_mulrn mul1n /inZp /= !Zp_cast //. - - rewrite [modn (S 0) (muln p q)]modn_small //. - rewrite [modn (S 0) q]modn_small // !mul1n. - rewrite modn_mul2 modn_prod2r. - rewrite mulnA (mulnC 2%N p) -mulnA. - rewrite leq_pmul2l //. - lia. - - rewrite (modn_small pqbig) (modn_small qbig) !mul1n. - assert (p < p*q). - { - replace (p) with (p*1)%N at 1 by lia. - rewrite ltn_pmul2l //; lia. - } - rewrite (modn_small H) -muln_modr. - rewrite mulnA (mulnC 2%N p) -mulnA. - rewrite leq_pmul2l; [| lia]. - assert ((((p * q - n.+1 %% (p * q)) %% (p * q)) %% q) = - ((q - n.+1 %% q) %% q))%N. - { - rewrite modn_muln_r //. - rewrite modnB; [|lia|lia]. - rewrite modnMl addn0. - rewrite modn_muln_r //. - case: ltP; intros. - - rewrite mul1n. - assert (q - n.+1 %% q < q) by lia. - by rewrite (modn_small H0). - - rewrite mul0n sub0n. - assert (n.+1 %% q = 0)%N by lia. - by rewrite H0 subn0 modnn. - } - by rewrite H0. - } - case: leqP; rewrite reduce_prod2 //; lia. -Qed. - -Lemma lift_reduce_prod2 (p q : nat) (a : {poly int}) : - 1 < q -> - 1 < p -> - zlift (q_reduce (p * q) a *+ p) = - zlift (q_reduce q a) *+p. -Proof. - intros. - rewrite /zlift -polyP => i. - rewrite coefMn !coef_map_id0 // coefMn. - rewrite /q_reduce coef_map_id0 //. - by apply liftc_reduce_prod2. -Qed. - -Lemma zlift_valid {q : nat} (c : {poly 'Z_q}) : - 1 < q -> - q_reduce q (zlift c) = c. -Proof. - intros. - rewrite /q_reduce /zlift -polyP => i. - by rewrite !coef_map_id0 // zliftc_valid. -Qed. - -Lemma linearize_prop_mult {q p : nat} (c2 : {poly 'Z_q}) - (s e : {poly int}) (a : {poly 'Z_(p*q)}) : - let c2' := q_reduce (p*q) (zlift c2) in - (map_poly (fun P => c2' * P) (ev_key s e a)).[q_reduce (p * q) s] = - c2' * (q_reduce (p*q) (exp s 2) *+ p) + c2' * (q_reduce (p * q) e). -Proof. - rewrite /ev_key /key_switch_key. - rewrite map_Poly_id0; [|by rewrite mulr0]. - rewrite horner_Poly /= mul0r add0r. - rewrite !(zlift_add2_eq,mulrDr) rmorphMn /=. - ring. -Qed. - -Lemma linearize_prop_div {q p : nat} (c2 : {poly 'Z_q}) - (s e : {poly int}) (a : {poly 'Z_(p*q)}) : - let c2' := q_reduce (p*q) (zlift c2) in - q_reduce q (div_round (zlift ((map_poly (fun P => c2' * P) (ev_key s e a)).[q_reduce (p * q) s])) p) = - q_reduce q (div_round (zlift (c2' * (q_reduce (p*q) (exp s 2) *+ p) + c2' * (q_reduce (p * q) e))) p). -Proof. - simpl. - by rewrite linearize_prop_mult. -Qed. - -Lemma q_reduce_0_div (q : nat) (qbig: 1 < q) (p : {poly int}) : - q_reduce q p = 0 -> - { e : {poly int} | p = e *+ q}. -Proof. - intros. - exists (map_poly (fun c => intdiv.divz c q) p). - apply polyP. - intros ?. - rewrite /q_reduce in H. - rewrite coefMn coef_map_id0; [|by rewrite intdiv.div0z]. - rewrite [LHS](intdiv.divz_eq _ q). - assert (intdiv.modz p`_x q = 0). - { - move: H. - rewrite -(map_poly0 [eta intr]). - move/polyP/(_ x). - rewrite !coef_map_id0 //. - rewrite coef0 /=. - move/(f_equal val)=> /= HH. - apply int_Zp_0 => //. - by apply ord_inj. - } - lia. -Qed. - -Lemma poly_Zq_muln_q {q : nat} (qbig:1 < q) (a : {poly 'Z_q}) : - a *+ q = 0. -Proof. - rewrite -polyP => i. - rewrite coefMn coef0 Zp_mulrn. - apply ord_inj => /=. - by rewrite {3}Zp_cast // modnMl. -Qed. - -Lemma q_reduce_muln_q (q : nat) (qbig:1 < q) (a : {poly int}) : - q_reduce q (a *+ q) = 0. -Proof. - by rewrite reduce_prod1 poly_Zq_muln_q. -Qed. - -Lemma zlift_red (q : nat) (p : {poly int}) : - 1 < q -> - { e : {poly int} | - zlift (q_reduce q p) = p + e *+q}. -Proof. - intros. - assert (q_reduce q (zlift (q_reduce q p) - p) = 0). - { - rewrite rmorphD rmorphN /= zlift_valid //. - ring. - } - apply q_reduce_0_div in H0 => //. - destruct H0. - exists x. - rewrite -e. - ring. -Qed. -Lemma linearize_prop_div2 {q p : nat} (qbig : 1 < q) (pbig : 1 < p) (c2 : {poly 'Z_q}) - (s e : {poly int}) (a : {poly 'Z_(p*q)}) : - let c2' := q_reduce (p*q) (zlift c2) in - { e2 : {poly int} | - q_reduce q (div_round (zlift (c2' * (q_reduce (p*q) (exp s 2) *+ p) + c2' * (q_reduce (p * q) e))) p) = - c2 * (q_reduce q (exp s 2)) + q_reduce q (div_round ((zlift c2) * e) p + e2)}. -Proof. - assert (pqbig: (1 < p * q)) by lia. - assert (pno: (Posz p <> 0)) by lia. - destruct (zlift_red (p * q) (zlift c2 * e) pqbig). - exists - ( - div_round_add2_perturb (zlift c2 * e) (x *+ (p * q)) p pno + - div_round - (zlift_add2_perturb pqbig (q_reduce (p * q) (zlift c2 * s ^+ 2) *+ p) - (q_reduce (p * q) (zlift c2 * e))) p + - div_round_add2_perturb - (zlift (c2 * q_reduce q (s ^+ 2)) *+ p + zlift c2 * e + x *+ (p * q)) - (zlift_add2_perturb pqbig (q_reduce (p * q) (zlift c2 * s ^+ 2) *+ p) - (q_reduce (p * q) (zlift c2 * e))) p pno). - rewrite (zlift_add2_eq,mulrDr) //. - rewrite div_round_add2_eq //. - rewrite !rmorphD /= -rmorphMn -rmorphM /=. - rewrite mulrnAr rmorphMn /=. - rewrite lift_reduce_prod2 //. - rewrite div_round_muln_add_l; [|lia]. - rewrite rmorphD /=. - rewrite zlift_valid // rmorphM /= zlift_valid // -!addrA. - f_equal. - rewrite -rmorphM /=. - rewrite e0. - rewrite div_round_add2_eq // !rmorphD /= -!addrA. - replace (q_reduce q (div_round (x *+ (p * q)) p)) with (0 : {poly 'Z_q}). - - by rewrite add0r. - - rewrite mulnC mulrnA. - assert (p <> 0%N) by lia. - by rewrite div_round_muln // q_reduce_muln_q. -Qed. - -Lemma absz_triang_sum {n} (a : 'I_n -> int) : - `|\sum_(j a (fintype.lift ord0 z))). - lia. -Qed. - -Lemma add_zero (a b : nat) : - b = 0%N -> - a + b <= a. -Proof. - lia. -Qed. - -Lemma delta_maxnorm_prod (a b : {poly int}) : - icoef_maxnorm (a * b) <= (size b) * icoef_maxnorm a * icoef_maxnorm b. -Proof. - rewrite /icoef_maxnorm. - apply /bigmax_leqP. - intros. - rewrite coefMr. - eapply leq_trans. - apply absz_triang_sum. - rewrite /= -mulnA. - assert (\sum_(j < i.+1) `|(a`_(i - j) * b`_j)%R| <= - \sum_(j < size b) `|(a`_(i - j) * b`_j)%R|)%N. - { - case: (boolP (i < size b)); intros. - - replace (size b) with (i.+1 + (size b - i.+1))%nat by lia. - rewrite big_split_ord /=. - assert (0 <= \sum_(i0 < size b - i.+1) `|(a`_(i - (i.+1 + i0)) * b`_(i.+1 + i0))%R|). - { - apply big_ind; lia. - } - lia. - - replace (i.+1) with (size b + (i.+1 - size b))%nat by lia. - rewrite big_split_ord /=. - apply add_zero. - under eq_big_seq. - + intros. - rewrite abszM. - rewrite [b`_ _]nth_default /=. - * rewrite muln0. - over. - * simpl. - lia. - + by rewrite sum_nat_const muln0. - } - eapply leq_trans; [apply H0 |]. - move: (@big_sum_le_const _ (index_enum (ordinal_finType (size b))) - (fun j => (absz - (mul (R:=int_Ring) (nth (zero int_Ring) a (subn i j)) - (nth (zero int_Ring) b j)))) - ((\max_(j < size a) `|a`_j|) * (\max_(j < size b) `|b`_j|))). - rewrite /index_enum /= -!enumT !size_enum_ord. - apply => p _. - rewrite abszM. - apply leq_mul; simpl. - clear H0. - - case (ltnP (i - p) (size a)) => ineq. - + eapply leq_trans; cycle 1. - by apply leq_bigmax_seq; rewrite ?mem_enum. - by rewrite (_:(nat_of_ord i - nat_of_ord p)%nat=(nat_of_ord ((Ordinal (ineq))))) ?leqnn. - + assert (a`_(i - p) = 0). - { - by apply nth_default. - } - rewrite H0 /=. - lia. - - eapply leq_trans; cycle 1. - by apply leq_bigmax_seq; rewrite ?mem_enum. - easy. -Qed. - -Lemma Rabs_mul (a b : R) : - Rabs a * Rabs b = Rabs (a * b)%R. -Proof. - by rewrite /mul /= Rabs_mult. -Qed. - -Lemma nearest_round_le (r1 r2 : R) : - (r1 <= r2)%O -> - (nearest_round r1 <= nearest_round r2)%O. -Proof. - intros. - rewrite /nearest_round /ran_round. - generalize (upi_le r1 r2 H); intros. - case : (boolP (Order.lt _ _)); intros. - - case : (boolP (Order.lt _ _)); intros; trivial. - assert (upi r1 < upi r2)%O. - { - rewrite -(ltr_int R_numDomainType). - rewrite -(ler_int R_numDomainType) in H0. - lra. - } - lia. - - case : (boolP (Order.lt _ _)); lia. -Qed. - -Lemma Rmult_leb_compat_r (r r1 r2 : R) : ((0 : R) <= r)%O -> (r1 <= r2)%O -> (r1 * r <= r2 * r)%O. -Proof. - move/RlebP=> HH1. - move/RlebP=> HH2. - by apply/RlebP/Rmult_le_compat_r. -Qed. - -Lemma nearest_round_int_le (r1 r2 d : int) : - ((0 : int) <= d)%O -> - (r1 <= r2)%O -> - (nearest_round_int r1 d <= nearest_round_int r2 d)%O. -Proof. - intros. - rewrite /nearest_round_int. - apply nearest_round_le. - apply Rmult_leb_compat_r. - - rewrite invr_ge0. - rewrite (_:((0 : R) = 0%:~R)); last by lra. - by rewrite ler_int. - - by rewrite ler_int. -Qed. - -Lemma nearest_round_pos (r : R) : - ((0 : R) <= r)%O -> - ((0 : int) <= nearest_round r)%O. -Proof. - intros. - rewrite -nearest_round_0. - by apply nearest_round_le. -Qed. - -Lemma nearest_round_sgn_pos (r : R) : - ((0 : R) <= r)%O -> - ((0 : int) <= nearest_round_sgn r)%O. -Proof. - intros. - rewrite /nearest_round_sgn. - case: (boolP (Order.lt _ _)); intros. - - assert ((0 : R) < 0)%O by lra. - by rewrite Order.POrderTheory.ltxx in H0. - - apply nearest_round_pos. - lra. -Qed. - -Lemma nearest_round_neg (r : R) : - (r <= 0 )%O -> - (nearest_round r <= 0)%O. -Proof. - intros. - rewrite -nearest_round_0. - by apply nearest_round_le. -Qed. - -Lemma nearest_round_sgn_neg (r : R) : - (r <= 0)%O -> - (nearest_round_sgn r <= 0)%O. -Proof. - intros. - rewrite /nearest_round_sgn. - case: (boolP (Order.lt _ _)); intros. - - assert (-r >= (0 : R))%O by lra. - apply nearest_round_pos in H0. - rewrite Ropp_opp. - lra. - - assert (r = 0) by lra. - rewrite H0 nearest_round_0. - lra. -Qed. - -Lemma nearest_round_int_pos (p : nat) (c : int) : - p != 0%N -> - ((0 : int) <= c)%O -> - ((0 : int) <= nearest_round_int c p)%O. -Proof. - intros. - rewrite -(nearest_round_int0 p). - apply nearest_round_le. - rewrite mul0r. - by apply div_int_pos. -Qed. - -Lemma nearest_round_int_neg (p : nat) (c : int) : - p != 0%N -> - ((0 : int) >= c)%O -> - ((0 : int) >= nearest_round_int c p)%O. -Proof. - intros. - rewrite -(nearest_round_int0 p). - apply nearest_round_le. - rewrite mul0r. - generalize div_int_neg; intros. - by apply div_int_neg'. -Qed. - -Lemma odd_mul_half (p : nat) : - odd p -> - ~(exists (c : int), - c %:~R = (1/2 : R) *+ p). -Proof. - intros ??. - destruct H0. - apply (f_equal (fun z => 2 * z)) in H0. - assert ((2 * (x %:~R) :R) <> p %:~R). - { - replace (2 * x %:~R) with ((2 * x)%:~R : R) by lra. - intros ?. - move /eqP in H1. - rewrite eqr_int in H1. - destruct x. - - assert (2 * n = p)%N by lia. - rewrite -H2 in H. - replace (2 * n)%N with (n.*2) in H by lia. - by rewrite odd_double in H. - - lia. - } - rewrite H0 mulrnAr mulrA mulr1 in H1. - by rewrite divff in H1; [|lra]. -Qed. - -Lemma odd_div_upi_not_half (p : nat) (c : int) : - odd p -> - (((upi (c%:~R / p%:~R)%R)%:~R : R) - c%:~R / p%:~R != (1 / 2 : R))%O. -Proof. - intros. - apply /eqP. - intros ?. - apply (f_equal (fun z => z *+ p)) in H0. - assert (exists (b : int), - b %:~R = (1/2 : R) *+ p). - { - exists ((upi (c%:~R / p%:~R)%R)*+p - c). - rewrite -H0. - field. - replace (zero (Ring.zmodType (ssrnum.Num.NumField.ringType R_numFieldType))) - with (((0%:R):R)) by lra. - rewrite eqr_nat. - lia. - } - by generalize (odd_mul_half p H); intros. -Qed. - -Lemma odd_div_upi_compare (p : nat) (c : int) : - odd p -> - (((upi (c%:~R / p%:~R)%R)%:~R : R) - c%:~R / p%:~R < (1 / 2 : R))%O || - (((upi (c%:~R / p%:~R)%R)%:~R : R) - c%:~R / p%:~R > (1 / 2 : R))%O. -Proof. - intros. - generalize (odd_div_upi_not_half p c H); intros. - apply (Order.TotalTheory.lt_total H0). -Qed. - -Lemma nearest_round_int_negate (p : nat) (c : int) : - odd p -> - nearest_round_int (-c) p = - nearest_round_int c p. -Proof. - intros. - rewrite /nearest_round_int. - assert ((((- c)%:~R : R)/ p%:~R)%R = - (c %:~R / p%:~R)%R). - { - lra. - } - rewrite H0. - apply nearest_round_opp_not_half. - apply /eqP. - by apply odd_div_upi_not_half. -Qed. - -Lemma opp_Ropp (r : R) : - opp r = Ropp r. -Proof. - by rewrite /opp /=. -Qed. - -Lemma nearest_round_sgn_negate (r : R) : - nearest_round_sgn (-r) = - nearest_round_sgn r. -Proof. - rewrite /nearest_round_sgn. - rewrite -!opp_Ropp. - case: (boolP (Order.lt r 0)); intros. - - case: (boolP (Order.lt _ _)); intros; try lra. - rewrite ltrNl oppr0 in p0. - generalize (Order.POrderTheory.lt_trans p0 p); intros. - by rewrite Order.POrderTheory.ltxx in H. - - case: (boolP (Order.lt _ _)); intros. - + by rewrite opprK. - + assert (r = 0%R) by lra. - by rewrite H oppr0 nearest_round_0 oppr0. -Qed. - -Lemma nearest_round_int_sgn_negate (p : nat) (c : int) : - nearest_round_int_sgn (-c) p = - nearest_round_int_sgn c p. -Proof. - rewrite /nearest_round_int_sgn. - replace ((- c)%:~R / p%:~R)%R with - (- (c%:~R / p%:~R)%R :R) by lra. - by rewrite nearest_round_sgn_negate. -Qed. - -Lemma nearest_round_int_sgn_nat (n d : nat) : - d != 0%N -> - nearest_round_int n d = nearest_round_int_sgn n d. -Proof. - rewrite /nearest_round_int /nearest_round_int_sgn /nearest_round_sgn. - case: (boolP (Order.lt _ _)); trivial. - intros. - generalize (div_nat_pos n d H); intros. - assert ((0 : R) < (0 : R))%O. - { - eapply Order.POrderTheory.le_lt_trans. - apply H0. - apply p. - } - by rewrite Order.POrderTheory.ltxx in H1. -Qed. - -Lemma nearest_round_int_sgn_odd (d : nat) (n : int) : - odd d -> - nearest_round_int n d = nearest_round_int_sgn n d. -Proof. - intros. - destruct n. - - rewrite nearest_round_int_sgn_nat; trivial. - apply odd_gt0 in H. - lia. - - rewrite NegzE. - rewrite nearest_round_int_negate //. - rewrite nearest_round_int_sgn_negate. - rewrite nearest_round_int_sgn_nat; trivial. - lia. -Qed. - -Lemma nearest_round_sgn_le' (r1 r2 : R) : - (Rabs r1 <= Rabs r2)%O -> - (Posz `|nearest_round_sgn r1| <= Posz `|nearest_round_sgn r2|)%O. -Proof. - intros. - rewrite /nearest_round_sgn. - case: (boolP (Order.lt _ _)); intros. - - rewrite Rabs_left in H; [|apply /RltP; rewrite /zero /= in p; lra]. - case: (boolP (Order.lt _ _)); intros. - + rewrite Rabs_left in H; [|apply /RltP; rewrite /zero /= in p0; lra]. - rewrite !lez0_abs. - * rewrite !opprK. - by apply nearest_round_le. - * assert ((0 : R) <= (- r2)%R)%O by lra. - generalize (nearest_round_pos (- r2) H0); lra. - * assert ((0 : R) <= (- r1)%R)%O by lra. - generalize (nearest_round_pos (- r1) H0); lra. - + rewrite Rabs_right in H; cycle 1. - * rewrite /zero /= in i. - apply Rle_ge. - apply /RleP. - lra. - * rewrite abszN !gez0_abs. - -- by apply nearest_round_le. - -- apply nearest_round_pos; lra. - -- apply nearest_round_pos. - rewrite Ropp_opp; lra. - - rewrite Rabs_right in H; cycle 1. - + move /RltP in i. - rewrite /zero /= in i. - coq_lra. - + case: (boolP (Order.lt _ _)); intros. - * rewrite Rabs_left in H; [|apply /RltP; rewrite /zero /= in p; lra]. - rewrite abszN !gez0_abs. - -- by apply nearest_round_le. - -- apply nearest_round_pos; lra. - -- apply nearest_round_pos; lra. - * rewrite Rabs_right in H; cycle 1. - -- rewrite /zero /= in i0. - apply Rle_ge. - apply /RleP. - lra. - -- rewrite !gez0_abs. - ++ by apply nearest_round_le. - ++ apply nearest_round_pos; lra. - ++ apply nearest_round_pos; lra. -Qed. - -Lemma nearest_round_sgn_le (r1 r2 : R) : - (Rabs r1 <= Rabs r2)%O -> - `|nearest_round_sgn r1| <= `|nearest_round_sgn r2|. -Proof. - intros. - apply nearest_round_sgn_le' in H. - lia. -Qed. - -Lemma Rzero_0 : - (0 : R) = IZR Z0. -Proof. - by rewrite /zero /=. -Qed. - -Lemma nearest_round_int_sgn_le (n1 n2 d : int) : - ((0 : int) < d)%O -> - `|n1| <= `|n2| -> - `|nearest_round_int_sgn n1 d| <= `|nearest_round_int_sgn n2 d|. -Proof. - intros. - assert ((0 : R) < d%:~R)%O. - { - replace (0 : R) with (((0 : int)%:~R):R) by ring. - by rewrite ltr_int. - } - assert (d%:~R != (0 : R))%O. - { - assert (d != 0) by lia. - by rewrite -(eqr_int R_numDomainType) in H2. - } - rewrite /nearest_round_int_sgn. - apply nearest_round_sgn_le. - case: (boolP ((0 : int) <= n1)%O); intros. - - rewrite Rabs_right; cycle 1. - + apply Rle_ge. - apply /RleP. - rewrite -(ler_pM2r H1) mul0r -mulrA (mulrC _ d%:~R) divff; [|apply H2]. - rewrite mulr1. - rewrite -(ler_int R_numDomainType) in p. - lra. - + case: (boolP ((0 : int) <= n2)%O); intros. - * rewrite Rabs_right; cycle 1. - -- apply Rle_ge. - apply /RleP. - rewrite -(ler_pM2r H1) mul0r -mulrA (mulrC _ d%:~R) divff; [|apply H2]. - rewrite mulr1. - rewrite -(ler_int R_numDomainType) in p0. - lra. - -- rewrite -(ler_pM2r H1) -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2]. - rewrite !mulr1 ler_int. - lia. - * rewrite Rabs_left; cycle 1. - -- apply /RltP. - rewrite -(ltr_pM2r H1) -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2]. - rewrite mulr1 mulr0. - rewrite -Order.TotalTheory.ltNge -(ltr_int R_numDomainType) in i. - lra. - -- rewrite -(ler_pM2r H1) Ropp_opp mulNr -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2]. - rewrite !mulr1 -mulrNz ler_int. - lia. - - rewrite Rabs_left; cycle 1. - + apply /RltP. - rewrite -(ltr_pM2r H1) -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2]. - rewrite mulr1 mulr0. - rewrite -Order.TotalTheory.ltNge -(ltr_int R_numDomainType) in i. - lra. - + case: (boolP ((0 : int) <= n2)%O); intros. - * rewrite Rabs_right; cycle 1. - -- apply Rle_ge. - apply /RleP. - rewrite -(ler_pM2r H1) mul0r -mulrA (mulrC _ d%:~R) divff; [|apply H2]. - rewrite mulr1. - rewrite -(ler_int R_numDomainType) in p. - lra. - -- rewrite -(ler_pM2r H1) Ropp_opp mulNr -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2]. - rewrite !mulr1 -mulrNz ler_int. - lia. - * rewrite Rabs_left; cycle 1. - -- apply /RltP. - rewrite -(ltr_pM2r H1) -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2]. - rewrite mulr1 mulr0. - rewrite -Order.TotalTheory.ltNge -(ltr_int R_numDomainType) in i0. - lra. - -- rewrite -(ler_pM2r H1) Ropp_opp !mulNr -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2]. - rewrite !mulr1 -!mulrNz ler_int. - lia. -Qed. - -Lemma nearest_round_leq (r : R) : - (Rabs(r - (nearest_round r)%:~R) <= 1 / 2)%O. -Proof. - rewrite /nearest_round /ran_round. - case_eq (Order.lt ((upi r)%:~R -r) (1 / 2)); intros. - - rewrite Rabs_minus_sym. - rewrite Rabs_right. - + rewrite /Rminus Rplus_add Ropp_opp. - lra. - + destruct (upi_bound r). - rewrite /Rminus Rplus_add Ropp_opp. - apply Rle_ge. - apply /RleP. - move /RltP in H0. - replace (IZR Z0) with (zero R_zmodType). - * lra. - * by rewrite /zero /=. - - rewrite /Rminus Rplus_add Ropp_opp. - rewrite Rabs_right. - + lra. - + destruct (upi_bound r). - apply Rle_ge. - apply /RleP. - move /RleP in H1. - replace (IZR Z0) with (zero R_zmodType). - * lra. - * by rewrite /zero /=. -Qed. - -Lemma nearest_round_sgn_leq (r : R) : - (Rabs(r - (nearest_round_sgn r)%:~R) <= 1 / 2)%O. -Proof. - rewrite /nearest_round_sgn. - case: (boolP (Order.lt _ _)); intros. - - generalize (nearest_round_leq (- r)); intros. - rewrite -Rabs_Ropp in H. - rewrite /Rminus Rplus_add !Ropp_opp in H. - rewrite opprD !opprK in H. - rewrite /Rminus Rplus_add !Ropp_opp. - eapply Order.POrderTheory.le_trans; cycle 1. - apply H. - apply /RlebP. - right. - f_equal. - ring. - - apply nearest_round_leq. -Qed. - -Lemma nearest_round_int_leq (p : nat) (a : int) : - p != 0%N -> - `|a - (nearest_round_int a p)*+p| <= p./2. -Proof. - intros. - rewrite /nearest_round_int. - generalize (nearest_round_leq (a%:~R / p%:~R)%R); intros. - assert (Rabs ((a - nearest_round (a%:~R / p%:~R)%R *+ p)%:~R : R)<= p%:R / 2)%O. - { - rewrite -mulr_natl. - assert ((0 : R) < p%:R)%O. - { - rewrite pmulrn_rgt0; [lia|lra]. - } - rewrite -(ler_pM2l H1) in H0. - assert (Rabs (p%:R) = p%:R). - { - rewrite Rabs_right //. - apply Rle_ge. - apply /RleP. - replace (IZR Z0) with (zero R_zmodType); trivial. - lra. - } - rewrite -{1}H2 Rabs_mul mulrDr in H0. - assert (Rabs (a - nearest_round (a%:~R / p%:~R)%R *+ p)%:~R = - Rabs (p%:R * (a%:~R / p%:~R) + p%:R * (- (nearest_round (a%:~R / p%:~R)%R)%:~R)%R)%R). - { - f_equal. - field. - lra. - } - rewrite -H3 in H0. - rewrite -mulrA. - apply H0. - } - assert ((Rabs (((a - nearest_round (a%:~R / p%:~R)%R *+ p)%:~R ) : R)) *+2 <= (p%:R/2) *+2)%O. - { - rewrite lerMn2r. - apply /orP. - by right. - } - replace (p%:R / 2 *+ 2) with ((p%:R):R) in H2 by (by field). - case: (boolP ((0 : R) <= (a - nearest_round (a%:~R / p%:~R)%R *+ p)%:~R *+ 2)%O); intros. - - rewrite Rabs_right in H2. - + rewrite -rmorphMn /= in H2. - replace (p%:R) with ((p%:~R):R) in H2 by ring. - rewrite ler_int in H2. - rewrite -rmorphMn ler0z in p0. - assert ((0 : int) <= (a - nearest_round (a%:~R / p%:~R)%R *+ p))%O by lia. - assert ((a - nearest_round (a%:~R / p%:~R)%R *+ p) <= p./2)%O by lia. - by rewrite -(gez0_abs H3) in H4. - + apply Rle_ge. - apply /RleP. - replace (IZR Z0) with (zero R_zmodType); trivial. - lra. - - rewrite Rabs_left in H2. - + rewrite Ropp_opp -rmorphN -rmorphMn /= in H2. - replace (p%:R) with ((p%:~R):R) in H2 by ring. - rewrite ler_int in H2. - rewrite -Order.TotalTheory.ltNge in i. - assert (((a - nearest_round (a%:~R / p%:~R)%R *+ p)%:~R : R) < 0%:~R )%O by lra. - rewrite ltr_int in H3. - assert (-(a - nearest_round (a%:~R / p%:~R)%R *+ p) <= p./2)%O by lia. - by rewrite -(ltz0_abs H3) in H4. - + apply /RltP. - replace (IZR Z0) with (zero R_zmodType); trivial. - lra. - Qed. - -Lemma div_round_leq (p : nat) (a : {poly int}) : - p != 0%N -> - icoef_maxnorm (a - (div_round a p)*+p) <= p./2. -Proof. - rewrite /icoef_maxnorm /div_round=> pn0. - apply/bigmax_leqP=> i _. - by rewrite coefB coefMn coef_map_id0 ?nearest_round_int0 // nearest_round_int_leq. -Qed. - - -Lemma div_round_eq (p : nat) (a : {poly int}) : - p != 0%N -> - { c : {poly int} | - a = (div_round a p)*+p + c /\ - icoef_maxnorm c <= p./2}. -Proof. - intros. - exists (a - (div_round a p)*+p). - split. - - ring. - - by apply div_round_leq. -Qed. - - -Lemma icoef_maxnorm_triang (a b : {poly int}) : - icoef_maxnorm (a + b) <= icoef_maxnorm a + icoef_maxnorm b. -Proof. - rewrite /icoef_maxnorm. - apply/bigmax_leqP => i _. - eapply leq_trans; first by rewrite coefD; apply absz_triang. - apply leq_add. - - case: (boolP (i < size a)); intros. - + eapply leq_trans; cycle 1. - apply leq_bigmax_seq; rewrite ?mem_index_enum //. - by apply (@leq_trans `|a`_(Ordinal p)|). - + rewrite nth_default //. - by rewrite leqNgt i0. - - case: (boolP (i < size b)); intros. - + eapply leq_trans; cycle 1. - apply leq_bigmax_seq; rewrite ?mem_index_enum //. - by apply (@leq_trans `|b`_(Ordinal p)|). - + rewrite nth_default //. - by rewrite leqNgt i0. -Qed. - -Lemma icoef_maxnorm_neg (a : {poly int}) : - icoef_maxnorm (-a) = icoef_maxnorm a. -Proof. - rewrite /icoef_maxnorm size_opp. - apply eq_big; trivial. - intros. - by rewrite coefE abszN. -Qed. - -Lemma nearest_round_int_mulp (a p : int) : - p != 0 -> - a = nearest_round_int (p * a) p. -Proof. - intros. - rewrite /nearest_round_int. - replace ((p * a)%:~R / p%:~R)%R with (a %:~R : R). - - by rewrite nearest_round_int_val. - - field. - move: H. - apply contra_neq. - apply (@intr_inj R_numDomainType p 0). -Qed. - -Lemma nearest_round_int_le_const (p : nat) (a c : int) : - Posz p != 0 -> - (a <= c * p)%O -> - (nearest_round_int a p <= c)%O. -Proof. - intros. - rewrite (nearest_round_int_mulp c p H). - apply nearest_round_int_le; trivial. - by rewrite mulrC. -Qed. - -Lemma nearest_round_int_le_const_nat (p a c : nat) : - Posz p != 0 -> - (a <= c * p) -> - (nearest_round_int a p <= c)%O. -Proof. - intros. - by apply nearest_round_int_le_const. -Qed. - -Lemma nearest_round_int_odd_abs (p : nat) (a : int) : - odd p -> - `|nearest_round_int a p| = `|nearest_round_int `|a| p|. -Proof. - intros. - destruct a. - - by rewrite absz_nat. - - rewrite NegzE nearest_round_int_negate //. - by rewrite !abszN absz_nat. -Qed. - -Lemma nearest_round_int_leq2 (c p : nat) (a : int) : - odd p -> - `|a| <= c * p -> - `|nearest_round_int a p| <= c. -Proof. - intros. - assert (Posz p != 0) by lia. - generalize (nearest_round_int_le_const_nat _ _ _ H1 H0); intros. - rewrite nearest_round_int_odd_abs //. - assert (p != 0%N) by lia. - assert ((0 : int) <= `|a|)%O by lia. - generalize (nearest_round_int_pos p `|a| H3 H4); intros. - rewrite {1}/absz. - case E: (nearest_round_int `|a| p); lia. -Qed. - -Lemma icoef_maxnorm_div_round_leq (c p : nat) (a : {poly int}) : - odd p -> - icoef_maxnorm a <= c * p -> - icoef_maxnorm (div_round a p) <= c. -Proof. - rewrite /icoef_maxnorm /div_round. - intros oodp pmax. - apply /bigmax_leqP => i _. - rewrite coef_map_id0. - - apply nearest_round_int_leq2; trivial. - move /bigmax_leqP in pmax. - have pfle: (size (map_poly (aR:=int_Ring) (rR:=int_Ring) (nearest_round_int^~ p) a)) <= size a - by rewrite size_poly. - apply (pmax (widen_ord pfle i) isT). - - by rewrite nearest_round_int0. -Qed. - -Lemma div_round_maxnorm_le (p : nat) (a : {poly int}): - odd p -> - icoef_maxnorm (div_round a p) <= `|nearest_round_int (icoef_maxnorm a) p|. -Proof. - intros. - rewrite /icoef_maxnorm /div_round. - apply /bigmax_leqP; intros. - rewrite coef_poly. - case: (boolP (i < size a)); intros; try lia. - rewrite nearest_round_int_odd_abs //. - assert (nearest_round_int `|a`_i| p <= nearest_round_int (\max_(j < size a) `|a`_j|) p)%O. - { - apply nearest_round_int_le. - - lia. - - have pfle: ((size (map_poly (aR:=int_Ring) (rR:=int_Ring) (nearest_round_int^~ p) a)) <= size a). - by rewrite size_poly. - by apply (@leq_bigmax_seq _ (index_enum (ordinal_finType (size a))) xpredT (fun j => `|a`_j|)%O - (widen_ord pfle i)). - } - assert (p != 0%N) by lia. - assert ((0 : int) <= nearest_round_int `|a`_i| p)%O. - { - by apply nearest_round_int_pos. - } - assert ((0 : int) <= nearest_round_int (\max_(j < size a) `|a`_j|) p)%O. - { - by apply nearest_round_int_pos. - } - lia. -Qed. - -Lemma sum_bound (n B : nat) (F : nat -> nat) : - (forall j, F j <= B) -> - \sum_(j ineq. - + by rewrite nth_default. - + by apply: (@leq_bigmax_seq _ (index_enum (ordinal_finType (size a))) xpredT (fun j => `|a`_j|)%O (Ordinal ineq)). - - case/orP: (leqVgt (size b) (i - i0)) => ineq. - + by rewrite nth_default. - + by apply: (@leq_bigmax_seq _ (index_enum (ordinal_finType (size b))) xpredT (fun j => `|b`_j|)%O (Ordinal ineq)). - } - rewrite sum_nat_const card_ord mulnA in H1. - eapply leq_trans; try apply H1. - - case/orP: (leqVgt i.+1 (size a)) => ineq. - - replace (size a) with (i.+1 + (size a - i.+1))%nat by lia. - rewrite big_split_ord /=. - lia. - - rewrite (_:i.+1 = (size a + (i.+1 - size a))%N); last by lia. - rewrite big_split_ord /=. - suff ->: (\sum_(i0 < i.+1 - size a) `|(a`_(size a + i0) * b`_(i - (size a + i0)))%R|)%nat = 0%nat - by lia. - apply/eqP. - rewrite sum_nat_eq0. - apply/forall_inP => j _. - rewrite nth_default /=; lia. -Qed. - -Lemma mul_half2 (a b : nat) : - a <= b ./2 <-> - 2 * a <= b . -Proof. - lia. -Qed. - -Lemma mul_half3 (a b c : nat) : - a <= b * c./2 -> - 2 * a <= b * c. -Proof. - lia. -Qed. - -Lemma mul_half4 (a p : nat) : - odd p -> - a * p./2 <= p * (a + 1)./2. -Proof. - intros. - generalize (odd_halfK H); intros. - apply (f_equal (fun z => z.+1)) in H0. - replace (p.-1.+1) with p in H0 by lia. - rewrite -{2}H0. - case: (boolP (odd a)); intros. - - generalize (odd_halfK p0); intros. - apply (f_equal (fun z => z.+1)) in H1. - replace (a.-1.+1) with a in H1 by lia. - rewrite -{1}H1. - apply (f_equal (fun z => (z + 1)./2)) in H1. - rewrite -H1. - replace (a./2.*2.+1 + 1)./2 with (a./2 + 1)%N by lia. - lia. - - apply even_halfK in i. - rewrite -i. - replace ((a./2.*2 + 1)./2) with (a./2) by lia. - lia. -Qed. - -Lemma div_round_mul_ex (p : nat) (a b : {poly int}) : - odd p -> - { c : {poly int} | - (div_round a p) * b = - (div_round (a * b) p) + c /\ - icoef_maxnorm c <= ((size b) * icoef_maxnorm b + 2)./2}. -Proof. - intros oddp. - have pno: (p != 0%N) by lia. - destruct (div_round_eq p a pno) as [?[??]]. - destruct (div_round_eq p (a * b) pno) as [?[??]]. - move /eqP in pno. - apply (f_equal (fun z => z * b)) in H. - rewrite mulrDl in H. - rewrite {1}H mulrnAl in H1. - apply (f_equal (fun z => z - x * b)) in H1. - rewrite -!addrA subrr addr0 in H1. - apply (f_equal (fun z => div_round z p)) in H1. - rewrite div_round_muln // addrC in H1. - rewrite div_round_muln_add // addrC in H1. - exists (div_round (x0 - x * b) p). - split; trivial. - assert (pno': Posz p != 0) by lia. - apply icoef_maxnorm_div_round_leq; trivial. - rewrite mulnC. - eapply leq_trans. - apply icoef_maxnorm_triang. - rewrite icoef_maxnorm_neg. - assert (icoef_maxnorm (x * b) <= (size b * icoef_maxnorm b) * p./2). - { - generalize (icoef_maxnorm_mul b x); intros. - rewrite mulrC in H3. - eapply leq_trans. - apply H3. - apply leq_mul; lia. - } - rewrite addnC. - eapply leq_trans. - - apply leq_add. - apply H3. - apply H2. - - assert (size b * icoef_maxnorm b * p./2 + p./2 <= (size b * icoef_maxnorm b + 1) * p./2) by lia. - eapply leq_trans. - apply H4. - eapply leq_trans. - apply mul_half4; trivial. - apply leq_mul; lia. -Qed. - -Lemma nearest_round_int_modz (n1 n2: int) (p1 p2 : nat) : - 0 < p1 -> - 1 < p2 -> - (n1%:~R : 'Z_(p1 * p2)) = (n2%:~R : 'Z_(p1 * p2)) -> - ((nearest_round_int n1 p1) %:~R : 'Z_p2) = ((nearest_round_int n2 p1) %:~R : 'Z_p2). -Proof. - intros. - assert (Posz p1 <> 0) by lia. - rewrite -int_Zp_eq_iff //. - apply nearest_round_int_mod; trivial. - rewrite int_Zp_eq_iff //. - lia. -Qed. - -Lemma div_round_mod (n1 n2 : {poly int}) (p1 p2 : nat) : - 0 < p1 -> - 1 < p2 -> - q_reduce (p1 * p2) n1 = q_reduce (p1 * p2) n2 -> - q_reduce p2 (div_round n1 p1) = q_reduce p2 (div_round n2 p1). -Proof. - rewrite /q_reduce /div_round !map_polyE. - intros. - rewrite -polyP => i. - rewrite !coef_Poly !(nth_map_default 0 0); try by rewrite rmorph0. - rewrite -polyP in H1. - specialize (H1 i). - rewrite !coef_Poly !(nth_map_default 0 0) in H1; try by rewrite rmorph0. - rewrite !coef_Poly !(nth_map_default 0 0); try by rewrite rmorph0. - - apply nearest_round_int_modz; trivial. - - by rewrite nearest_round_int0. - - by rewrite nearest_round_int0. -Qed. - -Lemma lineariz_prop_div3 {q p : nat} (qbig : 1 < q) (pbig : 1 < p) (oddp : odd p) (c2 : {poly 'Z_q}) - (s e : {poly int}) (a : {poly 'Z_(p*q)}) : - let c2' := q_reduce (p*q) (zlift c2) in - { e2 : {poly 'Z_q} | - q_reduce q (div_round (zlift ((map_poly (fun P => c2' * P) (ev_key s e a)).[q_reduce (p * q) s])) p) = - (map_poly (fun P => q_reduce q (div_round ((zlift c2) * (zlift P)) (p%:Z))) - (ev_key s e a)).[q_reduce q s] + e2}. -Proof. - assert (pqbig: (1 < p * q)) by lia. - assert (pno: (Posz p <> 0)) by lia. - destruct (div_round_mul_ex p (zlift c2 * zlift a) s oddp) as [?[??]]. - generalize (div_round_add2_perturb_small (zlift c2 * zlift a * s) - (zlift c2 * zlift (- a * q_reduce (p * q) s + q_reduce (p * q) e + q_reduce (p * q) (s ^+ 2 *+ p))) p pno); intros peterb_small. - exists - (q_reduce q ( (div_round_add2_perturb (zlift c2 * zlift a * s) - (zlift c2 * zlift (- a * q_reduce (p * q) s + q_reduce (p * q) e + q_reduce (p * q) (s ^+ 2 *+ p))) p pno - - x))). - rewrite !map_Poly_id0. - + rewrite !horner_Poly /= !mul0r !add0r. - rewrite -!rmorphM -rmorphD /=. - rewrite H -(addrA _ x) (addrC x _) addrA. - rewrite div_round_add2_eq_alt. - assert (q_reduce q - (div_round - (zlift - (q_reduce (p * q) (zlift c2) * a * q_reduce (p * q) s + - q_reduce (p * q) (zlift c2) * - (- a * q_reduce (p * q) s + q_reduce (p * q) e + q_reduce (p * q) (s ^+ 2 *+ p)))) p) = - q_reduce q - (div_round - (zlift c2 * zlift a * s + - zlift c2 * zlift (- a * q_reduce (p * q) s + q_reduce (p * q) e + q_reduce (p * q) (s ^+ 2 *+ p))) p)). - { - apply div_round_mod; try lia. - rewrite rmorphD /= zlift_valid; try lia. - f_equal. - - by rewrite !rmorphM /= !zlift_valid; try lia. - - by rewrite rmorphM /= zlift_valid; try lia. - } - rewrite H1 !rmorphD !rmorphN /= -[LHS]addr0 -!addrA. - f_equal. - ring. - + by rewrite (zlift0 0) // mulr0 div_round0 rmorph0. - + by rewrite mulr0. -Qed. - -Lemma linearize_prop {q p : nat} (qbig : 1 < q) (pbig : 1 < p) (c2 : {poly 'Z_q}) (s e : {poly int}) (a : {poly 'Z_(p*q)}) : - { e2 : {poly int} | - (map_poly (fun P => q_reduce q (div_round ((zlift c2) * (zlift P)) (p%:Z))) - (ev_key s e a)).[q_reduce q s] = - c2 * (q_reduce q (exp s 2)) + q_reduce q (div_round ((zlift c2) * e) p + e2) /\ - icoef_maxnorm e2 <= 1}. -Proof. - assert (pqbig: (1 < p * q)) by lia. - assert (pno: (Posz p <> 0)) by lia. - eexists; split. - - rewrite /ev_key /key_switch_key. - rewrite map_Poly_id0. - + rewrite horner_Poly /= mul0r add0r. - rewrite !(zlift_add2_eq,mulrDr) rmorphMn /=. - rewrite div_round_add2_eq. - rewrite lift_reduce_prod2 // mulrnAr /=. - rewrite div_round_muln_add; try lia. - rewrite !rmorphD /=. - rewrite rmorphM /=. - rewrite !zlift_valid //. - rewrite !div_round_add2_eq !rmorphD /=. - admit. - + by rewrite [zlift 0]zlift0 // mulr0 div_round0 rmorph0. - - admit. -Admitted. - -Definition rescale {q1 q2 : nat} (p : {poly 'Z_(q1 * q2)}) : {poly 'Z_q2} := - q_reduce q2 (div_round (zlift p) q1%:Z). - -Definition FHE_mult {q p : nat} (P Q : {poly {poly 'Z_q}}) - (evkey : {poly {poly 'Z_(p*q)}}) := - let PQ := FHE_mult_base P Q in - linearize (PQ`_0) (PQ`_1) (PQ`_2) evkey. - -Lemma decrypt_mult {p q : nat} (P Q : {poly 'Z_q}) (PP QQ : {poly {poly 'Z_q}}) - (s e : {poly int}) (a : {poly 'Z_(p*q)}) : - FHE_decrypt s PP = P -> - FHE_decrypt s QQ = Q -> - size PP = 2%N -> - size QQ = 2%N -> - {R : {poly int} | - FHE_decrypt s (FHE_mult PP QQ (ev_key s e a)) = - P * Q + q_reduce q (div_round (zlift (PP * QQ)`_2 * e) p + R) /\ - icoef_maxnorm R <= 1 }. -Proof. -(* - intros. - rewrite -(decrypt_mult_base P Q PP QQ s) //. - rewrite /FHE_mult /linearize /FHE_mult_base /FHE_decrypt hornerD. - rewrite linearize_prop. - assert (size (PP * QQ) <= 3%N). - { - generalize (size_mul_leq PP QQ); intros. - by rewrite H1 H2 in H3. - } - replace (q_reduce q (s ^+ 2)) with - ((q_reduce q s) ^+ 2). - - assert (PP * QQ = Poly [:: (PP * QQ)`_0; (PP * QQ)`_1; (PP * QQ)`_2]). - { - replace (PP * QQ) with (take_poly 3 (PP * QQ)) at 1. - - unfold take_poly. - rewrite -polyP. - intros ?. - rewrite coef_poly coef_Poly. - case: leqP => ineq. - + by rewrite nth_default. - + by rewrite -(nth_mkseq 0 (fun i => (PP * QQ)`_i) ineq). - - rewrite take_poly_id //. - } - rewrite {5}H4 !horner_Poly /= mul0r !add0r mulrDl addrC -!addrA. - f_equal. - + by rewrite expr2 mulrA. - + by rewrite addrC addrA. - - by rewrite rmorphXn. -Qed. - *) - Admitted. - -Definition key_switch {q p : nat} (c0 c1 : {poly 'Z_q}) - (ks_key : {poly {poly 'Z_(p*q)}}) : {poly {poly 'Z_q}} := - c0%:P + map_poly (fun P => q_reduce q (div_round ((zlift c1) * (zlift P)) (p%:Z))) - ks_key. - -Definition FHE_automorphism {q p : nat} (s e : {poly int}) - (a : {poly 'Z_(p*q)}) (P : {poly {poly 'Z_q}}) (j : nat) := - key_switch (comp_poly 'X^(2*j+1) P`_0) - (comp_poly 'X^(2*j+1) P`_1) - (key_switch_key s (comp_poly 'X^(2*j+1) s) e a). - -Lemma decrypt_FHE_automorphism_base {q p : nat} (s : {poly int}) (P : {poly 'Z_q}) - (PP : {poly {poly 'Z_q}}) (j : nat) : - FHE_decrypt s PP = P -> - comp_poly 'X^(2*j+1) P = FHE_decrypt (comp_poly 'X^(2*j+1) s) - (map_poly (comp_poly 'X^(2*j+1)) PP). -Proof. - rewrite /FHE_decrypt. - intros. - replace (q_reduce q (s \Po 'X^(2 * j + 1))) with - (comp_poly 'X^(2*j+1) (q_reduce q s)). - - by rewrite horner_map /= H. - - rewrite /q_reduce map_comp_poly /=. - f_equal. - by rewrite map_polyXn. -Qed. - diff --git a/coq/FHE/encode.v b/coq/FHE/encode.v deleted file mode 100644 index 06184a80..00000000 --- a/coq/FHE/encode.v +++ /dev/null @@ -1,5286 +0,0 @@ -Require Import Reals Lra Lia List Permutation. -From mathcomp Require Import common ssreflect fintype bigop ssrnat matrix Rstruct complex seq fingroup. -From mathcomp Require Import ssralg ssrfun. -From mathcomp Require Import generic_quotient ring_quotient. -From mathcomp Require Import poly mxpoly polydiv ssrint zmodp eqtype ssrbool div. - -Import ssralg.GRing. -Require Import nth_root. - -Ltac coq_lra := lra. -From mathcomp Require Import lra. - -Set Bullet Behavior "Strict Subproofs". - - -Lemma INR0eq (n:nat) : (INR n == 0%R) = (n == 0%nat). -Proof. - apply/eqP. - case: eqP. - - by move=> ->. - - by apply not_0_INR. -Qed. - -Lemma natmul0eq (n : nat) : ((n%:R)%R == 0%R :> R) = (n == 0%nat). -Proof. - by rewrite -INRE INR0eq. -Qed. - -Lemma modulo_modn (a b : nat) : - Nat.modulo a b = modn a b. -Proof. - case: b => [| b]. - - by rewrite modn0. - - rewrite modn_def /Nat.modulo. - move: (Nat.divmod_spec a b 0 b (le_refl _)). - case: Nat.divmod => q u []. - rewrite Nat.mul_0_r Nat.sub_diag !Nat.add_0_r /= => eqq _. - rewrite (_:(edivn a b.+1) = (q, (b - u)%coq_nat)) // -(@edivn_eq b.+1); try lia. - rewrite eqq. - f_equal; lia. -Qed. - - -Local Open Scope ring_scope. - - -Section construct. - - Import Ring ComRing UnitRing Pdiv.IdomainDefs Pdiv.WeakIdomain. - - Lemma size_poly1P_w [R : Exports.ringType] (p : {poly R}) : - size p == 1%nat -> {c | c != 0 & p = c%:P}. - Proof. - move/eqP=> pC. - have def_p: p = (p`_0)%:P - by rewrite -size1_polyC ?pC. - by exists p`_0; rewrite // -polyC_eq0 -def_p -size_poly_eq0 pC. - Qed. - - Lemma eqpP_w [R : idomainType] (m n : {poly R}) : - (m %= n) -> - {c12 | (c12.1 != 0) && (c12.2 != 0) & c12.1 *: m = c12.2 *: n}. - Proof. - case: (eqVneq m 0) => [-> /andP [/dvd0pP -> _] | m_nz]. - { by exists (1, 1); rewrite ?scaler0 // oner_eq0. } - case: (eqVneq n 0) => [-> /andP [_ /dvd0pP ->] | n_nz /andP []]. - { by exists (1, 1); rewrite ?scaler0 // oner_eq0. } - rewrite !dvdp_eq; set c1 := _ ^+ _; set c2 := _ ^+ _. - set q1 := _ %/ _; set q2 := _ %/ _; move/eqP => Hq1 /eqP Hq2; - have Hc1 : c1 != 0 by rewrite expf_eq0 lead_coef_eq0 negb_and m_nz orbT. - have Hc2 : c2 != 0 by rewrite expf_eq0 lead_coef_eq0 negb_and n_nz orbT. - have def_q12: q1 * q2 = (c1 * c2)%:P. - apply: (mulIf m_nz); rewrite mulrAC mulrC -Hq1 -scalerAr -Hq2 scalerA. - by rewrite -mul_polyC. - have: q1 * q2 != 0 by rewrite def_q12 -size_poly_eq0 size_polyC mulf_neq0. - rewrite mulf_eq0; case/norP=> nz_q1 nz_q2. - have: size q2 <= 1. - have:= size_mul nz_q1 nz_q2; rewrite def_q12 size_polyC mulf_neq0 //=. - by rewrite polySpred // => ->; rewrite leq_addl. - rewrite leq_eqVlt ltnS size_poly_leq0 (negPf nz_q2) orbF. - case/size_poly1P_w=> c cn0 cqe; exists (c2, c); first by rewrite Hc2. - by rewrite Hq2 -mul_polyC -cqe. - Qed. - - Lemma Bezoutp_w [R : idomainType] (p q : {poly R}) : {u | u.1 * p + u.2 * q %= (gcdp p q)}. - Proof. - have [-> | pn0] := eqVneq p 0. - - by rewrite gcd0p; exists (0, 1); rewrite mul0r mul1r add0r eqpxx. - - have [-> | qn0] := eqVneq q 0. - + by rewrite gcdp0; exists (1, 0); rewrite mul0r mul1r addr0 eqpxx. - + pose e := egcdp p q; exists e; rewrite eqp_sym. - by case: (egcdpP pn0 qn0). - Qed. - - Lemma Bezout_coprimepP_w [R : idomainType] (p q : {poly R}) : - coprimep p q -> - {u | u.1 * p + u.2 * q %= 1}. - Proof. - rewrite -gcdp_eqp1. - case: (Bezoutp_w p q) => [[u v] Puv]. - intros g1. - by exists (u, v); apply: eqp_trans g1. - Qed. - - Lemma Bezout_eq1_coprimepP_w [F : fieldType] (p q : {poly F}) : - coprimep (R:=F) p q -> - {u : poly_ringType F * poly_ringType F | u.1 * p + u.2 * q = 1}. - Proof. - move/ Bezout_coprimepP_w. - case=> [[u v]] /=. - case/eqpP_w=> [[c1 c2]] /andP /= [c1n0 c2n0] e. - exists (c2^-1 *: (c1 *: u), c2^-1 *: (c1 *: v)); rewrite /= -!scalerAl. - by rewrite -!scalerDr e scalerA mulVf // scale1r. - Qed. - -End construct. - - Section Chinese. - -(* The chinese remainder theorem for polynomials overa a field *) -Variables F : fieldType. - - -Definition chinesep (m1 m2 r1 r2 : {poly F}) (co_m12: coprimep (R:=F) m1 m2) := - let u := sval (Bezout_eq1_coprimepP_w m1 m2 co_m12) in - r1 * m2 * u.1 + r2 * m1 * u.2. - - Lemma chinesep_prop (m1 m2 r1 r2 : {poly F}) : - coprimep (R:=F) m1 m2 -> - {e : poly_ringType F | - e %% m1 = r1 %% m1 /\ e %% m2 = r2 %% m2}. - Proof. - intros co_m12. - destruct (Bezout_eq1_coprimepP_w m1 m2 co_m12). - pose c := r2 * x.1 * m1 + r1 * x.2 * m2. - exists c. - split. - - rewrite modpD modp_mull add0r. - apply (f_equal (fun z => z %% m1)) in e. - rewrite modpD modp_mull add0r in e. - by rewrite -mulrA -modp_mul e modp_mul mulr1. - - rewrite modpD modp_mull addr0. - apply (f_equal (fun z => z %% m2)) in e. - rewrite modpD modp_mull addr0 in e. - by rewrite -mulrA -modp_mul e modp_mul mulr1. -Qed. - - Lemma all_coprimep_prod (a : {poly F}) (l : seq {poly F}) : - all (coprimep a) l -> - coprimep a (\prod_(i <- l) i). - Proof. - intros. - rewrite big_seq. - apply big_rec. - - apply coprimep1. - - intros. - move/allP/(_ _ H0): H. - by rewrite coprimepMr H1 => ->. - Qed. - - Lemma mod_mul_mod_r (a p q : {poly F}) : - (a %% (p * q))%%q = a%%q. - Proof. - generalize (divp_eq a (p * q)); intros. - apply (f_equal (fun z => z %% q)) in H. - rewrite H. - by rewrite modpD mulrA modp_mull add0r. - Qed. - - Lemma mod_mul_mod_l (a p q : {poly F}) : - (a %% (p * q))%%p = a%%p. - Proof. - rewrite mulrC. - apply mod_mul_mod_r. - Qed. - - Lemma prod_mod (a b : {poly F}) (l : seq {poly F}) : - a %% (\prod_(q <- l) q) = b %% (\prod_(q <- l) q) -> - forall p, - p \in l -> a %% p = b %% p. - Proof. - induction l. - - intros. - by rewrite in_nil in H0. - - intros. - rewrite in_cons in H0. - move /orP in H0. - rewrite !big_cons in H. - destruct H0. - + rewrite (eqP H0). - simpl in H. - apply (f_equal (fun z => z %% a0)) in H. - by rewrite !mod_mul_mod_l in H. - + apply (f_equal (fun z => z %% (\prod_(j <- l) j))) in H. - rewrite !mod_mul_mod_r in H. - specialize (IHl H). - by apply IHl. - Qed. - -Lemma chinesep_list_prop (l : seq ({poly F} * {poly F})) : - pairwise (coprimep (R := F)) (map snd l) -> - { e : {poly F} | - forall p, - p \in l -> e %% p.2 = p.1 %% p.2}. -Proof. - induction l. - - simpl. - intros. - exists 0. - intros. - by rewrite in_nil in H0. - - intros. - rewrite map_cons pairwise_cons in H. - move /andP in H. - destruct H. - specialize (IHl H0). - destruct IHl. - assert (coprimep a.2 (\prod_(q <- l) q.2)). - { - generalize (all_coprimep_prod a.2 [seq i.2 | i <- l] H); intros. - by rewrite big_map in H1. - } - destruct (chinesep_prop a.2 (\prod_(q <- l) q.2) a.1 x H1) as [? [??]]. - exists x0. - intros. - rewrite in_cons in H4. - move /orP in H4. - destruct H4. - + by rewrite (eqP H4). - + specialize (e p H4). - rewrite -e. - assert (\prod_(q <- l) q.2 = \prod_(q <- [seq i.2 | i <- l]) q). - { - by rewrite big_map. - } - rewrite H5 in H3. - generalize (prod_mod x0 x [seq i.2 | i <- l] H3 p.2); intros. - apply H6. - by apply map_f. - Qed. - -End Chinese. - - -Definition odd_nth_roots (n : nat) := - \row_(j < 2^n) (nth_root (2 * j + 1) (2 ^ (S n))). - -Definition even_nth_roots (n : nat) := - \row_(j < 2^n) (nth_root (2 * j) (2 ^ (S n))). - -Definition nth_roots_half (n : nat) := - \row_(j < 2^n) (nth_root j (2 ^ (S n))). - -Lemma unity_root_nth_root (j n : nat) : - n.+1.-unity_root (nth_root j n.+1). -Proof. - apply /unity_rootP. - by rewrite nth_root_npow /RtoC /=. -Qed. - -Lemma primitive_root_nth_root (n : nat) : - n.+1.-primitive_root (nth_root 1 n.+1). -Proof. - intros. - rewrite /primitive_root_of_unity. - apply/andP. - split; try lia. - apply /forallP; intros. - apply /eqP. - apply /unity_rootP. - case: (eqVneq x.+1 n.+1); intros. - - by rewrite e nth_root_npow /RtoC /=. - - rewrite exp_nth_root muln1. - apply nth_root_not_1. - rewrite Nat.mod_small; try lia. - destruct x; simpl in *. - lia. -Qed. - -Lemma primitive_root_nth_root_coprime (j n : nat) : - coprime j n.+1 -> - n.+1.-primitive_root ((nth_root 1 n.+1) ^+ j). -Proof. - intros. - rewrite prim_root_exp_coprime //. - apply primitive_root_nth_root. -Qed. - -Lemma primitive_root_nth_root_coprime_alt (j n : nat) : - coprime j n.+1 -> - n.+1.-primitive_root (nth_root j n.+1). -Proof. - intros. - generalize (primitive_root_nth_root_coprime j n H); intros. - by rewrite exp_nth_root muln1 in H0. -Qed. - -Lemma pow2_S (j:nat) : - { k : nat | (2^j)%nat == S k}. -Proof. - exists (2^j-1)%nat. - induction j. - - now simpl. - - rewrite /= expnS (eqP IHj); lia. -Defined. - -Lemma primitive_root_odd_nth_root (j n : nat) : - (2^(n.+1)).-primitive_root ((nth_root 1 (2^n.+1)) ^+ (2 * j + 1)). -Proof. - destruct (pow2_S (S n)). - rewrite (eqP i). - apply primitive_root_nth_root_coprime. - rewrite -(eqP i); clear i. - assert (forall j0, coprime (2 * j0 + 1) 2). - { - intros. - rewrite coprimen2. - induction j0. - * by rewrite muln0 add0n /=. - * replace (2 * j0.+1 + 1)%N with (2 + (2 * j0 + 1))%N by lia. - by rewrite oddD IHj0 /=. - } - induction n. - - by rewrite expn1 H. - - by rewrite expnS coprimeMr IHn // H. -Qed. - -Lemma primitive_root_odd_nth_root_alt (j n : nat) : - (2^(n.+1)).-primitive_root (nth_root (2 * j + 1) (2^n.+1)). -Proof. - generalize (primitive_root_odd_nth_root j n); intros. - destruct (pow2_S (S n)). - rewrite (eqP i). - by rewrite (eqP i) exp_nth_root muln1 in H. -Qed. - -Lemma pow2_nth_root_pow1 (n j : nat) : - forall (e : nat), - (nth_root j (2^n.+1)) ^+ e = 1 <-> - (e * j) mod (2^n.+1) = 0%N. -Proof. - intros. - destruct (pow2_S n.+1). - by rewrite (eqP i) exp_nth_root nth_root_1_iff. -Qed. - -Lemma pow2_nth_root_pow_eq (n j : nat) : - forall (e1 e2 : nat), - (nth_root j (2^n.+1)) ^+ e1 = - (nth_root j (2^n.+1)) ^+ e2 <-> - (e1 * j) = (e2 * j) %[mod (2^n.+1)]. -Proof. - intros. - rewrite -!modulo_modn. - apply nth_root_pow_eq. - lia. -Qed. - -Lemma mul_INR n m : - INR(n * m) = INR n * INR m. -Proof. - by rewrite mult_INR /mul /=. -Qed. - -Lemma even_nth_root_half (j n : nat) : - 0 < n -> - nth_root (2 * j) (2 * n) = nth_root j n. -Proof. - intros. - rewrite /nth_root. - apply /eqP. - rewrite eq_complex /=. - apply /andP. - assert (INR 2 != 0). - { - rewrite /zero/=. - apply/eqP. - coq_lra. - } - assert (INR 2 \is a unit). - { - by rewrite unitfE. - } - assert (INR n \is a unit). - { - rewrite unitfE INR0eq; lia. - } - assert (inv (INR 2) * (INR 2) = 1). - { - rewrite mulrC divff //. - } - split; apply /eqP; f_equal; - rewrite -![2 * PI * _ * _]mulrA; f_equal; - rewrite !mul_INR invrM // [INR 2 * _]mulrC -mulrA; f_equal; - by rewrite mulrC -mulrA H3 mulr1. -Qed. - -Lemma even_nth_root_half_pow (j n : nat) : - nth_root (2 * j) (2 ^ (S n)) = nth_root j (2^n). -Proof. - destruct (pow2_S n). - rewrite expnS (eqP i). - apply even_nth_root_half; lia. -Qed. - -Lemma lt_0_1 : - is_true (0 < 1). -Proof. - easy. -Qed. - -Lemma poly_size_eq0 {R:ringType} (p:{poly R}) : - seq.size p == 0%nat = (p == 0). -Proof. - rewrite -size_poly_leq0. - lia. -Qed. - -Definition peval_mat {n} (roots : 'rV[R[i]]_n) : 'M[R[i]]_(n,n) := - \matrix_(i < n, j < n) (exp (roots 0 i) j). - -Definition conj_mat {n1 n2} (m : 'M[R[i]]_(n1,n2)) := - map_mx conjc m. - -Definition Vscale {n} (c : R[i]) (v : 'rV[R[i]]_n) := - c *: v. - -Definition vector_sum {n} (v : 'rV[R[i]]_n) := - \sum_(j < n) (v 0 j). - -Definition inner_prod {n} (v1 v2 : 'rV[R[i]]_n) := - (v1 *m (v2^T)) 0 0. - -Definition H_inner_prod {n} (v1 v2 : 'rV[R[i]]_n) := - inner_prod v1 (conj_mat v2). - -Lemma vector_sum_scale {n} (c : R[i]) (v : 'rV[R[i]]_n) : - mul c (vector_sum v) = vector_sum (Vscale c v). -Proof. - unfold vector_sum. - unfold Vscale. - rewrite Theory.mulr_sumr. - erewrite eq_big_seq; [reflexivity |]. - simpl. - apply ssrbool.in1W; intros. - now rewrite mxE. -Qed. - -Definition ConstVector n (c : R[i]) : 'rV[R[i]]_n:= const_mx c. - -Definition RtoC (x : R):R[i] := Complex x 0. - -(* Coercion RtoC : R >-> complex. *) - -Lemma vector_sum_const {n} (c : R[i]) : - vector_sum (ConstVector n c) = mul (n%:R) c. -Proof. - rewrite /vector_sum /ConstVector. - (under eq_big_seq => - do (apply ssrbool.in1W => ?; rewrite mxE)). - rewrite big_const_ord iter_addr_0 Theory.mulr_natl//. -Qed. - -Lemma conj_transpose {n} (m : 'M[R[i]]_(n,n)) : - conj_mat (m^T) = (conj_mat m)^T. -Proof. - now rewrite map_trmx. -Qed. - -Lemma RtoCnat_eq n : RtoC (INR n) = n%:R. -Proof. - unfold RtoC. - induction n. - - now rewrite Theory.mulr0n. - - rewrite Theory.mulrSr S_INR -IHn /add/= add0r//. -Qed. - -(* testing notations *) -Definition C0': R[i] := 0. -Definition C1': R[i] := 1. -Definition Cplus' (x y : R[i]) := x + y. -Definition Cmult' (x y : R[i]) := x * y. -Definition Cexp' (x : R[i]) (n : nat) := x ^+ n. -Definition Cdiv' (x y : R[i]) := x / y. -Definition Cinv' (x : R[i]) := x^-1. - -Lemma peval_row (n : nat) : - forall n0, - row n0 (peval_mat (odd_nth_roots (S n))) = - \row_(j < 2^(S n)) (odd_nth_roots (S n) 0 n0) ^+ j. -Proof. - intros. - unfold row. - simpl. - unfold peval_mat. - apply eq_mx; intros ??. - now rewrite mxE. -Qed. - -Lemma pow_nth_root j n e : - (nth_root j (S n)) ^+ e = nth_root (e * j) (S n). -Proof. - apply exp_nth_root. -Qed. - -Lemma pow_nth_root' j n e : - n != 0%nat -> - (nth_root j n) ^+ e = nth_root (e * j) n. -Proof. - destruct n; [lia |]=>_. - apply pow_nth_root. -Qed. - -Lemma odd_nth_roots_pow (n : nat) (c : R[i]) : - c ^+ (2 ^ (S n)) = -1 -> - forall j, - (c ^+ (2 * j + 1)) ^+ (2 ^ (S n)) = -1. -Proof. - intros. - by rewrite exprAC -exprM exprM H exprD exprM expr1 expr2 mulrNN mulr1 expr1n mul1r. -Qed. - -Lemma odd_nth_roots_powk (n : nat) (c : R[i]) : - c ^+ (2 ^ (S n)) = -1 -> - forall j k, - (c ^+ ((2 * j + 1)^k)) ^+ (2 ^ (S n)) = -1. -Proof. - intros. - rewrite -exprM mulnC exprM H. - induction k. - - by rewrite expn0 expr1. - - by rewrite expnS mulnC exprM IHk exprD exprM expr1 expr2 mulrNN mulr1 expr1n mul1r. -Qed. - -Lemma odd_nth_roots_conj (n : nat) (c : R[i]) : - c ^+ (2 ^ n.+1) = -1 -> - (conjc c) ^+ (2^n.+1) = -1. -Proof. - intros. - by rewrite -rmorphXn /= H rmorphN1. -Qed. - -Lemma decode_encode_on_diag (n : nat): - let pmat := (peval_mat (odd_nth_roots (S n))) in - forall n0, - H_inner_prod (row n0 pmat) (row n0 pmat) = (2^S n)%:R. -Proof. - intros. - unfold H_inner_prod, inner_prod, pmat. - rewrite mxE. - under eq_big_seq. - - apply ssrbool.in1W; intros. - rewrite peval_row /odd_nth_roots !mxE. - destruct (pow2_S (S (S n))). - rewrite (eqP i) pow_nth_root mult_conj_root. - over. - - rewrite big_const_ord iter_addr_0//. -Qed. - -Lemma H_inner_prod_mat n (M : 'M[R[i]]_(n,n)) : - forall i j, - (M *m (conj_mat (M ^T))) i j = - H_inner_prod (row i M) (row j M). -Proof. - rewrite /H_inner_prod /inner_prod => i j. - rewrite !mxE //=. - apply eq_big_seq => ??. - rewrite !mxE//. -Qed. - -Lemma telescope_mult_bigop_aux (c : R[i]) (n : nat) : - (c - 1) * (\sum_(0 <= j < S n) (c ^+ j)) = - \sum_(0 <= j < S n) ((c^+(S j)) - (c ^+ j)). -Proof. - rewrite big_distrr. - simpl. - apply eq_big_seq; intros ??. - rewrite mulrBl. - rewrite mul1r. - f_equal. - rewrite exprSr. - now rewrite mulrC. -Qed. - -Lemma telescope_mult_bigop (c : R[i]) (n : nat) : - (c - 1) * (\sum_(0 <= j < S n) (c ^+ j)) = - c ^+ (S n) - 1. -Proof. - rewrite telescope_mult_bigop_aux. - rewrite telescope_sumr. - + now rewrite expr0. - + lia. -Qed. - -Lemma telescope_div (c : R[i]) (n : nat) : - c <> 1 -> - \sum_(0 <= j < S n) (c ^+ j) = - (c ^+ (S n) - 1) / (c - 1). -Proof. - intros. - generalize (telescope_mult_bigop c n); intros. - rewrite -H0 mulrC mulrA Cinv_l. - - now rewrite mul1r. - - unfold not. - intros. - clear H0. - replace C0 with (zero C) in H1 by reflexivity. - apply (f_equal (fun cc => add cc 1)) in H1. - by rewrite add0r -addrA (addrC _ 1) addrN addr0 in H1. -Qed. - -Lemma telescope_pow_0_nat (c : R[i]) (n : nat) : - c <> 1 -> - c ^+ (S n) = 1 -> - \sum_(0 <= j < S n) (c ^+ j) = C0. -Proof. - intros. - rewrite telescope_div; trivial. - by rewrite H0 addrN mul0r. -Qed. - -Lemma telescope_pow_0_ord (c : R[i]) (n : nat) : - c <> 1 -> - c ^+ (S n) = 1 -> - \sum_(j < S n) (c ^+ j) = C0. -Proof. - intros. - rewrite <- (telescope_pow_0_nat c n); trivial. - by rewrite /= big_mkord. -Qed. - -Lemma add_conj (c1 c2 : R[i]) : - (conjc c1) + (conjc c2) = conjc (c1 + c2). -Proof. - by rewrite rmorphD. -Qed. - -Lemma mul_conj (c1 c2 : R[i]) : - (conjc c1) * (conjc c2) = conjc (c1 * c2). -Proof. - by rewrite rmorphM. -Qed. - -Lemma exp_conj (c : R[i]) n : - conjc (c ^+ n) = (conjc c)^+n. -Proof. - by rewrite rmorphXn. -Qed. - -Lemma decode_encode_off_diag (n : nat): - let pmat := (peval_mat (odd_nth_roots (S n))) in - forall n1 n2, - n1 <> n2 -> - H_inner_prod (row n1 pmat) (row n2 pmat) = C0. -Proof. - intros. - unfold H_inner_prod, inner_prod. - unfold pmat, peval_mat. - rewrite mxE. - simpl. - destruct (pow2_S (S n)). - unfold odd_nth_roots. - generalize (telescope_pow_0_ord ((nth_root (2*n1+1) (2^S(S n))) * - (conjc (nth_root (2*n2+1) (2^S(S n))))) x); intros. - rewrite <- H0. - - rewrite <- (eqP i). - erewrite eq_big_seq; [reflexivity |]. - simpl. - apply ssrbool.in1W; intros. - rewrite !mxE exprMn_comm. - + rewrite exp_conj//. - + rewrite /comm mulrC//. - - unfold not; intros. - destruct (pow2_S (S (S n))). - rewrite (eqP i0) nth_root_conj_alt nth_root_mul in H1. - apply nth_root_1_iff in H1. - rewrite -(eqP i0) in H1. - clear H0 i i0 pmat x x0. - destruct n1 as [x xlt]; destruct n2 as [y ylt]; simpl in *. - assert (neq:x <> y). - { - intros HH. - apply H; subst. - f_equal; apply eqtype.bool_irrelevance. - } - clear H. - rewrite !modulo_modn in H1. - apply (f_equal ssrint.Posz) in H1. - revert H1. - rewrite -intdiv.modz_nat ssrint.PoszD -ssrint.subzn. - + rewrite -intdiv.modz_nat. - rewrite -intdiv.modzDm. - rewrite !addn1 intdiv.modzDl intdiv.modzNm. - rewrite !intdiv.modzDm expnSr. - destruct (@leP x y). - * rewrite -intdiv.modzDl intdiv.modz_small/=; lia. - * rewrite intdiv.modz_small/=; lia. - + lia. - - destruct (pow2_S (S (S n))). - rewrite (eqP i0) nth_root_conj_alt nth_root_mul exp_nth_root. - apply nth_root_1_iff. - rewrite -(eqP i0) -(eqP i). - clear H0 i i0 pmat x x0. - destruct n1 as [x xlt]; destruct n2 as [y ylt]; simpl in *. - assert (neq:x <> y). - { - intros HH. - apply H; subst. - f_equal; apply eqtype.bool_irrelevance. - } - clear H. - rewrite !modulo_modn. - replace (expn 2 (S (S n))) with (expn 2 (S n) * 2)%nat by (rewrite (expnS _ (S n)); lia). - rewrite -div.muln_modr -div.modnDm. - replace (2 * x)%nat with (x * 2)%nat by lia. - rewrite div.modnMDl. - replace (div.modn (2 ^ n.+1 * 2 - div.modn (2 * y + 1) (2 ^ n.+1 * 2)) 2) with - (div.modn 1 2). - + rewrite div.modnDm. - replace (1 + 1)%nat with 2%nat by lia. - rewrite div.modnn; lia. - + replace ( div.modn (2 * y + 1) (2 ^ n.+1 * 2)) with (2 * y + 1)%nat. - * rewrite div.modnB; try lia. - * rewrite div.modn_small; lia. - Qed. - -Lemma decode_encode_scalar_mx (n : nat): - let pmat := (peval_mat (odd_nth_roots (S n))) in - pmat *m (conj_mat (pmat^T)) = scalar_mx (2^S n)%:R. -Proof. - intros. - apply matrixP; intros ??. - do 2 rewrite mxE. - destruct (@eqtype.eqP _ x y); intros. - - rewrite e. - simpl. - generalize (decode_encode_on_diag n); intros. - rewrite eqtype.eq_refl mulr1n. - unfold pmat in *. - specialize (H y). - unfold H_inner_prod, inner_prod in H. - rewrite mxE in H. - rewrite <- H. - simpl. - erewrite eq_big_seq; [reflexivity |]. - apply ssrbool.in1W; intros. - now repeat rewrite mxE. - - simpl. - generalize (decode_encode_off_diag n x y n0); intros. - unfold H_inner_prod, inner_prod in H. - rewrite mxE/=/row in H. - (* I am sure there is a better way to do this *) - repeat rewrite /eqtype.eq_op/=. - destruct (@eqnP x y); simpl in *. - + elim n0. - now apply ord_inj. - + replace ((2 ^ n.+1)%:R *+ 0) with C0 by reflexivity. - rewrite <- H. - erewrite eq_big_seq; [reflexivity |]. - apply ssrbool.in1W; intros. - now repeat rewrite mxE. -Qed. - -Lemma decode_mat_encode_mat (n : nat) (cl : 'cV[R[i]]_(2^(S n))) : - let pmat := peval_mat (odd_nth_roots (S n)) in - let encmat := conj_mat (pmat^T) in - pmat *m (encmat *m cl) = (2^S n)%:R *: cl. -Proof. - simpl. - rewrite mulmxA. - generalize (decode_encode_scalar_mx n); intros. - rewrite H. - now rewrite mul_scalar_mx. -Qed. - -(* shows evaluation can be done by modified FFT of half size*) -Lemma peval_mat_prod (n : nat) : - peval_mat (odd_nth_roots (S n)) = - peval_mat (even_nth_roots (S n)) *m diag_mx (nth_roots_half (S n)). -Proof. - apply matrixP; intros ??. - unfold nth_roots_half, even_nth_roots, peval_mat. - rewrite mul_mx_diag. - repeat rewrite mxE. - destruct (pow2_S (S (S n))). - rewrite (eqP i) !pow_nth_root nth_root_mul. - f_equal. - lia. -Qed. - -(* shows enconding can be done by modified IFFT of half size*) -Lemma encode_mat_prod (n : nat) : - let pmat := peval_mat (odd_nth_roots (S n)) in - let encmat := (conj_mat (pmat^T)) in - encmat = - diag_mx (map_mx conjc (nth_roots_half (S n))) - *m - peval_mat (map_mx conjc (even_nth_roots (S n))). -Proof. - apply matrixP; intros ??. - unfold nth_roots_half, conj_mat, peval_mat, even_nth_roots. - rewrite mul_diag_mx. - repeat rewrite mxE. - destruct (pow2_S (S (S n))). - rewrite (eqP i) pow_nth_root -exp_conj mul_conj. - f_equal. - rewrite pow_nth_root nth_root_mul. - f_equal. - lia. -Qed. - -Definition vector_rev {n} {T} (v : 'rV[T]_n) := - \row_(i < n) v 0 (rev_ord i). - -Definition vector_rev_conj {n} (v : 'rV[R[i]]_n) := - forall i, - v 0 i = conjc (v 0 (rev_ord i)). - -Lemma vector_rev_conj_plus {n} (v1 v2 : 'rV[R[i]]_n) : - vector_rev_conj v1 -> - vector_rev_conj v2 -> - vector_rev_conj (map2_mx (fun (c1 c2 : R[i]) => add c1 c2) v1 v2). -Proof. - unfold vector_rev_conj; intros. - do 2 rewrite mxE. - rewrite H. - rewrite H0. - now rewrite -rmorphD. -Qed. - -Lemma vector_rev_conj_mult {n} (v1 v2 : 'rV[R[i]]_n) : - vector_rev_conj v1 -> - vector_rev_conj v2 -> - vector_rev_conj (map2_mx (fun (c1 c2 : R[i]) => mul c1 c2) v1 v2). -Proof. - unfold vector_rev_conj; intros. - do 2 rewrite mxE. - rewrite H; rewrite H0. - now rewrite -rmorphM. -Qed. - -Lemma vector_rev_conj_scale {n} (r : R) (v : 'rV[R[i]]_n) : - vector_rev_conj v -> - vector_rev_conj (Vscale (RtoC r) v). -Proof. - unfold vector_rev_conj; intros. - unfold Vscale. - rewrite mxE. - rewrite H. - rewrite mxE. - rewrite <- mul_conj. - f_equal. - unfold conjc, RtoC. - apply f_equal. - lra. -Qed. - -Lemma vector_rev_conj_const_R n (r : R) : - vector_rev_conj (ConstVector n (RtoC r)). -Proof. - unfold vector_rev_conj, ConstVector, RtoC; intros. - do 2 rewrite mxE. - unfold conjc. - apply f_equal. - lra. -Qed. - -Lemma vector_rev_conj_conj {n} (v : 'rV[R[i]]_n) : - vector_rev_conj v -> - vector_rev_conj (map_mx conjc v). -Proof. - unfold vector_rev_conj; intros. - do 2 rewrite mxE. - now rewrite H. -Qed. - -Lemma vector_rev_conj_exp {n} i (v : 'rV[R[i]]_n) : - vector_rev_conj v -> - vector_rev_conj (map_mx (fun c => exp c i) v). -Proof. - unfold vector_rev_conj; intros. - do 2 rewrite mxE. - rewrite H. - now rewrite exp_conj. -Qed. - -Lemma Cconj_im_0 (c : C) : - conjc c = c -> Im c = 0%R. -Proof. - intros. - destruct c. - move /eqP in H. - rewrite /conjc eq_complex /= in H. - move /andP in H. - destruct H. - simpl. - lra. -Qed. - -Lemma vector_rev_sum_rev {n} (v : 'rV[R[i]]_n) : - vector_rev_conj v -> - forall i, - Im ((v + vector_rev v) 0 i) = 0. -Proof. - intros. - rewrite /vector_rev !mxE H rev_ordK. - apply Cconj_im_0. - rewrite rmorphD /= conjcK addrC//. -Qed. - -Lemma vector_rev_reflect {n} (v : 'rV[R[i]]_n) i : - vector_rev v 0 i = v 0 (rev_ord i). -Proof. - rewrite mxE//. -Qed. - -Lemma vector_sum_rev {n} (v : 'rV[R[i]]_n) : - vector_sum v = vector_sum (vector_rev v). -Proof. - unfold vector_sum, vector_rev. - rewrite (reindex_inj rev_ord_inj)/=. - apply eq_big_seq, ssrbool.in1W => x. - rewrite mxE//. -Qed. - -Lemma vector_sum_add {n} (a b : 'rV[R[i]]_n) : - vector_sum (a + b) = vector_sum a + vector_sum b. -Proof. - unfold vector_sum. - cut (\sum_(j < n) (a 0 j + b 0 j) = \sum_(j < n) a 0 j + \sum_(j < n) b 0 j). - { - intros HH. - rewrite -HH/=. - apply eq_big_seq, ssrbool.in1W => x. - rewrite mxE//. - } - rewrite big_split //. -Qed. - -Lemma Im_add (a b:R[i]) : Im (a + b) = Im a + Im b. -Proof. - now destruct a; destruct b; simpl. -Qed. - -Lemma vector_sum_reals {n} (v : 'rV[R[i]]_n) : - (forall i, Im (v 0 i) = 0) -> - Im (vector_sum v) = 0. -Proof. - unfold vector_sum. - apply big_rec; simpl; trivial. - intros. - rewrite Im_add H1 H0// addr0//. -Qed. - -Lemma vector_rev_conj_sum {n} (v : 'rV[R[i]]_n) : - vector_rev_conj v -> - Im (vector_sum v) = 0%R. -Proof. - intros. - cut (Im (vector_sum v + vector_sum (vector_rev v)) = 0). - { - rewrite -vector_sum_rev. - destruct (vector_sum v); simpl. - rewrite /add /zero/=. - coq_lra. - } - rewrite -vector_sum_add vector_sum_reals//. - now apply vector_rev_sum_rev. -Qed. - -Lemma inner_product_as_sum {n} (v1 v2 : 'rV[R[i]]_n) : - inner_prod v1 v2 = vector_sum (map2_mx (fun a b => a * b) v1 v2). -Proof. - rewrite /inner_prod /mulmx/= mxE /vector_sum/=. - apply eq_big_seq, ssrbool.in1W => x. - rewrite /map2_mx /trmx !mxE//. -Qed. - -Lemma vector_rev_conj_inner {n} (v1 v2 : 'rV[R[i]]_n) : - vector_rev_conj v1 -> - vector_rev_conj v2 -> - Im (inner_prod v1 v2) = 0. -Proof. - intros. - rewrite inner_product_as_sum vector_rev_conj_sum//. - now apply vector_rev_conj_mult. -Qed. - -Lemma vector_rev_conj_odd_nth_roots (n : nat) : - vector_rev_conj (odd_nth_roots (S n)). -Proof. - unfold vector_rev_conj, odd_nth_roots. - intros. - do 2 rewrite mxE. - destruct (pow2_S (S (S n))). - rewrite (eqP i0). - rewrite nth_root_conj_alt. - f_equal. - rewrite -(eqP i0). - rewrite (expnS _ (S n)). - unfold rev_ord; simpl. - rewrite Nat.mod_small; try lia. - destruct i. - simpl. - lia. -Qed. - -Lemma mv_rev_conj_real (n1 n2 : nat) (mat : 'M[R[i]]_(n1,n2)) (cl : 'cV[R[i]]_n2) : - vector_rev_conj (cl^T) -> - (forall i, vector_rev_conj (row i mat)) -> - forall i, - Im ((mat *m cl) i 0) = 0. -Proof. - intros. - replace ((mat *m cl) i 0) with (((row i mat) *m cl) 0 0). - - generalize (vector_rev_conj_inner (row i mat) (cl^T)); intros HH. - unfold inner_prod in HH. - rewrite trmxK in HH. - now apply HH. - - repeat rewrite mxE. - apply eq_big_seq; intros ??. - now repeat rewrite mxE. -Qed. - -Lemma encode_mat_pow_odd_roots (n:nat) : - let pmat := (peval_mat (odd_nth_roots (S n))) in - forall i, - row i pmat^T = (map_mx (fun c => c ^+ i) (odd_nth_roots (S n))). -Proof. - intros. - unfold odd_nth_roots, pmat, peval_mat. - apply matrixP; intros ??. - now repeat rewrite mxE. -Qed. - -Lemma mat_encode_real {n} (cl : 'cV[R[i]]_(2^(S n))) : - let pmat := (peval_mat (odd_nth_roots (S n))) in - let encmat := (conj_mat (pmat^T)) in - vector_rev_conj (cl^T) -> - forall i, - Im ((encmat *m cl) i 0) = 0. -Proof. - intros. - apply mv_rev_conj_real; trivial. - intros. - unfold encmat, conj_mat. - rewrite <- map_row. - apply vector_rev_conj_conj; simpl. - rewrite encode_mat_pow_odd_roots. - apply vector_rev_conj_exp. - apply vector_rev_conj_odd_nth_roots. -Qed. - -Lemma Re_Im_0 (c : R[i]) : - Im c = 0 <-> c = RtoC (Re c). -Proof. - destruct c. - unfold RtoC; simpl. - split; intros. - - now rewrite H. - - inversion H. - unfold zero; simpl; coq_lra. -Qed. - -Lemma mat_encode_real_alt {n} (cl : 'cV[R[i]]_(2^(S n))) : - let pmat := (peval_mat (odd_nth_roots (S n))) in - let encmat := (conj_mat (pmat^T)) in - vector_rev_conj (cl^T) -> - encmat *m cl = map_mx (fun c => RtoC (Re c)) (encmat *m cl). -Proof. - intros. - apply matrixP => x y. - generalize (mat_encode_real cl H x) => HH. - apply Re_Im_0 in HH. - by rewrite ord1 {}HH !mxE. -Qed. - -Definition vector_rev_col {n} {T} (v : 'cV[T]_n) := - \col_(i < n) v (rev_ord i) 0. - -Program Definition vector_reflect_conj {n} (cl : 'cV[R[i]]_(2^n)) : 'cV[R[i]]_(2^(S n)) := - col_mx cl (conj_mat (vector_rev_col cl)). -Next Obligation. - intros. - rewrite expnS. - lia. -Qed. - -Lemma vector_reflect_conj_rev_conj {n} (cl : 'cV[R[i]]_(2^n)) : - vector_rev_conj (vector_reflect_conj cl)^T. -Proof. - unfold vector_rev_conj, vector_reflect_conj. - intros. - unfold eq_rect. - destruct (vector_reflect_conj_obligation_1 n). - unfold conj_mat, vector_rev_col. - repeat rewrite mxE. - destruct (splitP i); destruct (splitP (rev_ord i)); unfold rev_ord in *; simpl in *. - - destruct i; destruct j; destruct j0; simpl in *; lia. - - rewrite !mxE/= conjcK. - f_equal. - destruct j. - cut (m = 2 ^ n - k.+1)%nat. - { - intros; subst. - f_equal; apply eqtype.bool_irrelevance. - } - destruct i; simpl in *; subst; lia. - - rewrite !mxE/=. - do 2 f_equal. - destruct j. - cut (2 ^ n - k.+1 = m)%nat. - { - intros; subst. - f_equal; apply eqtype.bool_irrelevance. - } - destruct i; simpl in *; subst; lia. - - destruct i; destruct k; destruct k0; simpl in *; lia. -Qed. - -Definition CKKS_poly_encode {n} (cl : 'cV[R[i]]_(2^n)) : 'cV[R]_(2^(S n)) := - let pmat := (peval_mat (odd_nth_roots (S n))) in - let encmat := (conj_mat (pmat^T)) in - (inv (2 ^+ S n)) *: - (map_mx (fun c => Re c) (encmat *m (vector_reflect_conj cl))). - -Definition int_to_zmodp (i : int) (p : nat) : 'Z_p := i %:~R. - -Definition col_to_poly {n} (cl : 'cV[R]_n) := rVpoly (cl^T). -Definition col_to_poly2 {n} (cl : 'cV[int]_n) := rVpoly (cl^T). - - -Definition red_poly (pol : {poly int}) p' := - map_poly (fun c => int_to_zmodp c p') pol. - -From mathcomp Require Import order. - -(* 0 <= rand < 1 *) -Definition ran_round (x rand : R) := - let hi := up x in - if (Order.lt ((IZR hi) - x) rand)%R then hi else (Zminus hi 1). - -Definition nearest_round (x : R) := ran_round x (1/2). - -Definition mx_round {n m} (mat : 'M[R]_(n,m)) : 'M[int]_(n,m) := - map_mx (fun r => ssrZ.int_of_Z (nearest_round r)) mat. - -Definition CKKS_poly_encode_Z {n} (cl : 'cV[R[i]]_(2^n)) : 'cV[int]_(2^(S n)) := - mx_round (CKKS_poly_encode cl). - -Definition vector_proj_coef {n} (v1 v2 : 'rV[R[i]]_n) := - (H_inner_prod v1 v2) / (H_inner_prod v2 v2). - -(* this is multiplication for vectors mod monic p *) -Definition rv_mul_mod {n} (a b : 'rV[R]_n) (p : {poly R}) : 'rV[R]_n := - poly_rV (rVpoly(a) * rVpoly b %% p). - -(* poly rem x^n+1 *) -Definition rv_mul_mod_xn_1 {n} (a b : 'rV[R]_n) (n : nat) : 'rV[R]_n := - let prod := (rVpoly a * rVpoly b) in - poly_rV (take_poly n prod - drop_poly n prod). - -Lemma size_Xn_addC [R : ringType] (n :nat) (b : R) : - seq.size ('X^n.+1 + b%:P) = n.+2. -Proof. - rewrite size_addl size_polyXn// (leq_ltn_trans (size_polyC_leq1 b))//. -Qed. - -Lemma lead_coef_xn [R : unitRingType] (n : nat) (c : R) : - lead_coef ('X^n.+1 + c%:P) \is a unit. -Proof. - rewrite lead_coefDl. - - rewrite lead_coefXn unitr1 //. - - rewrite size_polyXn. - generalize (size_polyC_leq1 c); lia. -Qed. - -Lemma poly_rem_xn [R : idomainType] (n : nat) (c : R) (a : {poly R}) : - let p := 'X^n.+1 + polyC c in - a %% p = take_poly n.+1 a + (-c)*: (drop_poly n.+1 a %% p). -Proof. - simpl. - have H := lead_coef_xn n c. - generalize (size_Xn_addC n c); intros. - rewrite -{1}(poly_take_drop n.+1 a). - rewrite Pdiv.IdomainUnit.modpD; trivial. - f_equal. - - rewrite modp_small; trivial. - rewrite H0. - generalize (size_take_poly n.+1 a); intros. - apply H1. - - rewrite -Pdiv.IdomainUnit.modp_mul; trivial. - assert ('X^n.+1 %% ('X^n.+1 + c%:P) = -c%:P). - { - assert ('X^n.+1 = 1 * ('X^n.+1 + c%:P) -c%:P). - { - rewrite mul1r -addrA subrr addr0//. - } - rewrite -(Pdiv.IdomainUnit.modpP H H1)//. - rewrite size_Xn_addC size_opp ltnS. - rewrite (leq_trans (size_polyC_leq1 c))//. - } - rewrite H1 mulrC -Pdiv.IdomainUnit.modpZl // -mul_polyC polyCN //. -Qed. - -From mathcomp Require Import polydiv. - -Lemma Xn_add_c_monic [R : ringType] n (c: R) : - 'X^n.+1 + c%:P \is monic. -Proof. - rewrite monicE lead_coefDl. - - rewrite lead_coefXn //. - - rewrite size_polyXn. - generalize (size_polyC_leq1 c); lia. - Qed. - -Lemma poly_rem_xn_alt [R : comRingType] (n : nat) (c : R) (a : {poly R}) : - let p := 'X^n.+1 + polyC c in - Pdiv.CommonRing.rmodp a p = - take_poly n.+1 a + - (Pdiv.CommonRing.rmodp ((-c)*:(drop_poly n.+1 a)) p). -Proof. - simpl. - have H := Xn_add_c_monic n c. - generalize (size_Xn_addC n c); intros. - rewrite -{1}(poly_take_drop n.+1 a). - rewrite Pdiv.RingMonic.rmodpD; trivial. - f_equal. - - rewrite Pdiv.CommonRing.rmodp_small; trivial. - rewrite H0. - generalize (size_take_poly n.+1 a); intros. - apply H1. - - rewrite -Pdiv.RingMonic.rmodp_mulmr; trivial. - assert (Pdiv.CommonRing.rmodp ('X^n.+1) ('X^n.+1 + c%:P) = -c%:P). - { - assert ('X^n.+1 = 1 * ('X^n.+1 + c%:P) -c%:P). - { - rewrite mul1r -addrA subrr addr0//. - } - - rewrite {1}H1. - rewrite Pdiv.RingMonic.rmodp_addl_mul_small; trivial. - rewrite size_Xn_addC size_opp ltnS. - rewrite (leq_trans (size_polyC_leq1 c))//. - } - rewrite H1 mulrC; trivial. - rewrite -mul_polyC polyCN //. -Qed. - -Require Import Recdef. -Function poly_rem_xn_1 [R : ringType] (n : nat) (a : {poly R}) {measure seq.size a} : {poly R} := - let a1 := take_poly n.+1 a in - let a2 := drop_poly n.+1 a in - if a2 == 0 then a1 else - a1 - poly_rem_xn_1 n a2. -Proof. - intros. - rewrite size_drop_poly. - enough (seq.size a <> 0)%nat by lia. - intros eqq. - rewrite drop_poly_eq0 in teq. - - rewrite eq_refl// in teq. - - rewrite eqq//. -Defined. - -Arguments poly_rem_xn_1 [R] n a : rename. - -Lemma poly_rem_xn_id [R : ringType] n (a:{poly R}) : seq.size a <= n.+1 -> - poly_rem_xn_1 n a = a. -Proof. - functional induction poly_rem_xn_1 n a => slt. - - rewrite take_poly_id//. - - rewrite drop_poly_eq0// eqxx// in y. -Qed. - -Lemma poly_rem_xn_1_le [R : ringType] n (a:{poly R}) : is_true (seq.size (poly_rem_xn_1 n a) <= n.+1). -Proof. - functional induction poly_rem_xn_1 n a. - - rewrite size_take_poly//. - - rewrite (leq_trans (size_add _ _)) // size_opp geq_max IHp size_take_poly//. - Qed. - -Lemma poly_size_0 {R : ringType} : - seq.size (zero (poly_zmodType R)) = 0%nat. -Proof. - rewrite /zero /= size_polyC eqxx//. -Qed. - -Lemma drop_poly_eq0_iff {R : ringType} (m : nat) (p : {poly R}) : - seq.size p <= m <-> drop_poly m p = 0. -Proof. - split; intros. - - now apply drop_poly_eq0. - - generalize (size_drop_poly m p); intros. - rewrite H poly_size_0 in H0. - lia. - Qed. - -Lemma poly_rem_xn_1_pmod [R : idomainType] n (a : {poly R}) : - poly_rem_xn_1 n a = a %% ('X^n.+1 + 1%:P). -Proof. - functional induction poly_rem_xn_1 n a. - - move => /eqP in e. - apply drop_poly_eq0_iff in e. - rewrite take_poly_id //. - rewrite modp_small; trivial. - rewrite size_Xn_addC. - apply e. - - rewrite poly_rem_xn IHp scaleN1r//. -Qed. - -Definition equiv_xn_1 [R : ringType] (n : nat): rel {poly R} := - fun p => fun q => poly_rem_xn_1 n (p - q) == 0. - -Definition ideal_xn_1_pred [R : ringType] (n : nat) : pred (poly_zmodType R) := - fun p => poly_rem_xn_1 n p == 0. - -Lemma poly_rem_xn_1_1 [R : ringType] n : - poly_rem_xn_1 (R:=R) n 1 = 1. -Proof. - apply poly_rem_xn_id. - rewrite size_poly1//. -Qed. - -Lemma poly_rem_xn_0_0 [R : ringType] n : - poly_rem_xn_1 (R:=R) n 0 = 0. -Proof. - apply poly_rem_xn_id. - rewrite size_poly0//. -Qed. - -Lemma rmodp_monic_scale [R : ringType] (c : R) (p d : {poly R}) : - d \is monic -> - Pdiv.CommonRing.rmodp (c *: p) d = c *: (Pdiv.CommonRing.rmodp p d). -Proof. - move=> monic. - have sz: seq.size (c *: (Pdiv.CommonRing.rmodp (R:=R) p d)) < seq.size d - by rewrite (leq_ltn_trans (size_scale_leq c (Pdiv.CommonRing.rmodp (R:=R) p d)))// - Pdiv.CommonRing.ltn_rmodpN0// monic_neq0//. - - rewrite -(Pdiv.RingMonic.rmodp_addl_mul_small monic (c *: Pdiv.CommonRing.rdivp (R:=R) p d) sz). - by rewrite {1}(Pdiv.RingMonic.rdivp_eq monic p) scalerDr scalerAl. -Qed. - -Lemma rmodp_monic_opp [R : ringType] (p d : {poly R}) : - d \is monic -> - Pdiv.CommonRing.rmodp (- p) d = -(Pdiv.CommonRing.rmodp p d). -Proof. - rewrite -!scaleN1r. - by apply rmodp_monic_scale. -Qed. - -Lemma poly_rem_xn_1_pmod_alt [R : comRingType] n (a : {poly R}) : - poly_rem_xn_1 n a = Pdiv.CommonRing.rmodp a ('X^n.+1 + 1%:P). -Proof. - functional induction poly_rem_xn_1 n a. - - move => /eqP in e. - apply drop_poly_eq0_iff in e. - rewrite take_poly_id //. - rewrite Pdiv.Ring.rmodp_small; trivial. - rewrite size_Xn_addC. - apply e. - - rewrite poly_rem_xn_alt. - f_equal. - rewrite IHp. - rewrite -rmodp_monic_opp. - + f_equal. - rewrite scaleN1r //. - + apply Xn_add_c_monic. -Qed. - -Lemma poly_rem_xn_1_eq0_mul [R : comRingType] n (a b: {poly R}) : - poly_rem_xn_1 n b = 0 -> - poly_rem_xn_1 n (a * b) = 0. -Proof. - do 2 rewrite poly_rem_xn_1_pmod_alt; intros. - rewrite -Pdiv.RingMonic.rmodp_mulmr. - - rewrite H mulr0. - apply Pdiv.Ring.rmod0p. - - apply Xn_add_c_monic. - Qed. - -Lemma poly_rem_xn_1_eq0_add [R : comRingType] n (a b: {poly R}) : - poly_rem_xn_1 n a = 0 -> - poly_rem_xn_1 n b = 0 -> - poly_rem_xn_1 n (a + b) = 0. -Proof. - do 3 rewrite poly_rem_xn_1_pmod_alt; intros. - rewrite Pdiv.RingMonic.rmodpD. - - rewrite H H0 addr0 //. - - apply Xn_add_c_monic. - Qed. - -Lemma poly_rem_xn_1_eq0_opp [R : comRingType] n (a: {poly R}) : - poly_rem_xn_1 n a = 0 -> - poly_rem_xn_1 n (- a) = 0. -Proof. - replace (-a) with (-1 * a). - - apply poly_rem_xn_1_eq0_mul. - - apply Theory.mulN1r. -Qed. - -Lemma ideal_xn_1_pred_proper [R : comRingType] n : proper_ideal (ideal_xn_1_pred (R:=R) n). -Proof. - rewrite /proper_ideal /in_mem /mem/= /ideal_xn_1_pred. - split. - - rewrite poly_rem_xn_1_1. - apply oner_neq0. - - rewrite /prop_in1/= /in_mem /= => a b /eqP /poly_rem_xn_1_eq0_mul->//. -Qed. - -Lemma ideal_xn_1_pred_zmod [R : comRingType] n : zmodPred (ideal_xn_1_pred (R :=R) n). -Proof. - constructor. - - constructor; [constructor|]. - constructor. - + rewrite /in_mem //= /ideal_xn_1_pred poly_rem_xn_0_0 //. - + rewrite /in_mem //= /prop_in2 /ideal_xn_1_pred => a b. - rewrite /in_mem /mem /= => /eqP-eqq1 /eqP-eqq2. - rewrite poly_rem_xn_1_eq0_add //. - - rewrite /Pred.Exports.oppr_closed /mem /= /ideal_xn_1_pred => a. - rewrite /in_mem /= => /eqP-eqq1. - rewrite poly_rem_xn_1_eq0_opp //. -Qed. - -Definition ideal_xn_1_idealr [R : comRingType] n : idealr (ideal_xn_1_pred (R := R) n) - := MkIdeal (ideal_xn_1_pred_zmod n) (ideal_xn_1_pred_proper n). - -Definition qring_xn [R : comRingType] n - := { ideal_quot (KeyedPred (ideal_xn_1_idealr (R := R) n)) }. - -Definition foo1_add [R : comRingType] n (a b : qring_xn (R:=R) n) := a + b. - -Section polyops. - - Context {T:comRingType}. - - Definition monic_poly := {p:{poly T} | (p \is monic) && (seq.size p > 1)}. - - Import Pdiv. - Import CommonRing. - Definition princ_ideal_pred (p : {poly T}) : pred {poly T} := - fun q => rmodp q p == 0. -(* - fun q => q %% p == 0. -*) - -Lemma princ_ideal_proper (p : monic_poly) : - proper_ideal (princ_ideal_pred (val p)). -Proof. - intros. - unfold proper_ideal, princ_ideal_pred, in_mem, mem; split; simpl. - - rewrite rmodp_small. - + rewrite poly1_neq0//. - + destruct (andP (valP p)). - rewrite size_poly1 H0 //. - - intros ??. - rewrite -RingMonic.rmodp_mulmr // /in_mem/=. - move => /eqP->. - rewrite mulr0 rmod0p//. - now destruct (andP (valP p)). -Qed. - -Lemma princ_ideal_zmod (p : monic_poly) : - zmodPred (princ_ideal_pred (val p)). -Proof. - constructor. - - constructor; [constructor|]. - constructor. - + rewrite /in_mem //= /princ_ideal_pred rmod0p//. - + rewrite /in_mem //= /prop_in2 /princ_ideal_pred => a b. - rewrite /in_mem /mem /= RingMonic.rmodpD // /=. - * move => /eqP-> /eqP->. - rewrite addr0//. - * now destruct (andP (valP p)). - - rewrite /Pred.Exports.oppr_closed /mem /= /princ_ideal_pred => a. - rewrite /in_mem /= => /eqP-eqq1. - destruct (andP (valP p)). - rewrite rmodp_monic_opp // eqq1 oppr0 //. -Qed. - -Definition princ_ideal (p : monic_poly) : - idealr (princ_ideal_pred (val p)) - := MkIdeal (princ_ideal_zmod p) (princ_ideal_proper p). - -Definition qring (p : monic_poly) - := { ideal_quot (KeyedPred (princ_ideal p)) }. - -Section example. - Context (p: monic_poly). - - Definition foo_add (a b : qring p) := a + b. - Definition foo_mul (a b : qring p) := a * b. - - Local Open Scope quotient_scope. - - Definition lift (a : {poly T}) : qring p - := lift_cst (qring p) a. - - Definition proj (a : {poly T}) := (\pi_(qring p) a). - Definition proj2 (a : {poly T}) : qring p := (\pi a). - - Example something (a b : {poly T}) := a == b %[mod (qring p)]. - -End example. -(* -Require Import qpoly. -Section qpoly. - Context {T:comRingType}. -Definition cyclotomic n : {poly T} := ('X^n.+1 + 1%:P). -Definition qpoly_add n (p q : {qpoly (cyclotomic n)}) := p + q. -Definition qpoly_mul n (p q : {qpoly (cyclotomic n)}) := p * q. -Definition qpoly_inj n (p : {poly T}) := in_qpoly (cyclotomic n) p. -End qpoly. -*) - -End polyops. - -Section rmorphism. -Lemma RtoC_is_rmorphism : - rmorphism RtoC. -Proof. - constructor. - - intros ??. - unfold RtoC, add; simpl. - f_equal. - rewrite addrN //. - - split. - + intros ??. - unfold RtoC, mul; simpl. - f_equal. - * rewrite mulr0 oppr0 addr0 //. - * rewrite mulr0 mul0r addr0 //. - + unfold RtoC, one; simpl. - unfold one, real_complex_def; simpl. - f_equal. -Qed. - -Canonical RtoC_rmorphism := RMorphism RtoC_is_rmorphism. - -Lemma map_RtoC_is_rmorphism : - rmorphism (map_poly RtoC). -Proof. - apply map_poly_is_rmorphism. -Qed. - -Canonical map_RtoC_rmorphism := RMorphism map_RtoC_is_rmorphism. - -Definition peval_C (p : {poly R}) (x : C) : C := - (map_poly RtoC p).[x]. - -Lemma ev_C_is_rmorphism (x:C) : - rmorphism (fun (p : {poly R}) => peval_C p x). -Proof. - unfold peval_C. - constructor. - - intros ??. - rewrite -horner_evalE !rmorphB //. - - split. - + intros ??. - rewrite -horner_evalE !rmorphM //. - + rewrite -horner_evalE !rmorph1 //. -Qed. - -Canonical ev_C_rmorphism (x:R[i]) := RMorphism (ev_C_is_rmorphism x). - -Lemma comp_poly_is_rmorphism (p : {poly R}) : - rmorphism (fun (q : {poly R}) => comp_poly p q). -Proof. - constructor. - - intros ??. - by rewrite comp_polyB. - - split. - + intros ??. - by rewrite comp_polyM. - + by rewrite comp_polyC. -Qed. - -Canonical comp_poly_rmorphism (p : {poly R}) := RMorphism (comp_poly_is_rmorphism p). - -Lemma sum_conj (n : nat) (F : 'I_n -> R[i]) : - conjc (\sum_(i < n) (F i)) = \sum_(i - RtoC (Re c) = c. -Proof. - intros. - rewrite /RtoC. - destruct c; simpl. - apply /eqP. - rewrite eq_complex. - rewrite /=. - apply /andP. - split; trivial. - apply /eqP. - simpl in H. - by rewrite H. -Qed. - -Lemma RtoC_cnorm (c : R[i]) : - RtoC (cnorm c) = c * conjc c. -Proof. - rewrite /cnorm. - apply RtoC_Re_Im0. - destruct c. - simpl. - lra. -Qed. - -Lemma RtoC_ctrace (c : R[i]) : - RtoC (ctrace c) = c + conjc c. -Proof. - rewrite /ctrace. - apply RtoC_Re_Im0. - destruct c. - simpl. - lra. -Qed. - -Lemma RtoC_opp (r : R) : - RtoC (- r) = - RtoC r. -Proof. - rewrite /RtoC /=. - apply /eqP. - rewrite eq_complex /=. - apply /andP; split; apply /eqP. - - by rewrite /opp /=. - - lra. -Qed. - -Definition characteristic_polynomial (c : R[i]) : {poly R} := - 'X^2 + (- ctrace c) *: 'X + (cnorm c)%:P. - -Lemma size_charpoly (c : R[i]) : - size (characteristic_polynomial c) = 3%N. -Proof. - rewrite /characteristic_polynomial. - rewrite -addrA size_addl size_polyXn //. - case : (eqVneq (ctrace c) 0); intros. - - rewrite e oppr0 scale0r add0r. - rewrite size_polyC. - case : (eqVneq (cnorm c) 0); rewrite //. - - rewrite size_addl. - + rewrite size_scale. - * rewrite size_polyX //. - * by rewrite -eqr_opp opprK oppr0. - + rewrite size_scale. - * rewrite size_polyX size_polyC. - case : (eqVneq (cnorm c) 0); rewrite //. - * by rewrite -eqr_opp opprK oppr0. -Qed. - -Lemma monic_charpoly (c : R[i]) : - characteristic_polynomial c \is monic. -Proof. - apply /monicP. - rewrite /characteristic_polynomial. - rewrite -addrA. - rewrite lead_coefDl. - - apply lead_coefXn. - - rewrite size_polyXn. - case : (eqVneq (ctrace c) 0); intros. - + rewrite e oppr0 scale0r add0r. - rewrite size_polyC. - case : (eqVneq (cnorm c) 0); rewrite //. - + rewrite size_addl. - * rewrite size_scale. - -- rewrite size_polyX //. - -- by rewrite -eqr_opp opprK oppr0. - * rewrite size_scale. - -- rewrite size_polyX size_polyC. - case : (eqVneq (cnorm c) 0); rewrite //. - -- by rewrite -eqr_opp opprK oppr0. -Qed. - -Lemma characteristic_polynomial_correct (c : R[i]) : - (map_poly RtoC (characteristic_polynomial c)).[c] = 0. -Proof. - rewrite /map_poly size_charpoly /characteristic_polynomial. - rewrite horner_poly. - rewrite big_ord_recl big_ord_recl big_ord1 /= expr0 mulr1 /bump addn0 expr1 /=. - replace (addn (S 0) (S 0)) with 2%N by lia. - rewrite !coefD !coefZ !coefX !coefC !coefXn /=. - rewrite mulr0 mulr1 !addr0 !add0r. - rewrite rmorph1 RtoC_cnorm RtoC_opp RtoC_ctrace mul1r expr2 opprD mulrDl. - rewrite -addrA (addrC _ (c * c)) (addrA _ (c * c) _) -mulrDl (addrC _ c). - by rewrite addrN mul0r add0r mulrC -mulrDl addrN mul0r. -Qed. - -Lemma trace_conj (c : R[i]) : - ctrace c = ctrace (conjc c). -Proof. - by rewrite /ctrace conjcK addrC. -Qed. - -Lemma norm_conj (c : R[i]) : - cnorm c = cnorm (conjc c). -Proof. - by rewrite /cnorm conjcK mulrC. -Qed. - -Lemma cnorm_peval_C_conj (p : {poly R}) (c : C) : - cnorm (peval_C p c) = cnorm (peval_C p (conjc c)). -Proof. - by rewrite norm_conj peval_C_conj. -Qed. - -Lemma char_poly_conj (c : R[i]) : - characteristic_polynomial c = characteristic_polynomial (conjc c). -Proof. - by rewrite /characteristic_polynomial trace_conj norm_conj. -Qed. - -Lemma char_poly_conj_alt (c : R[i]) : - (map_poly RtoC (characteristic_polynomial c)).[conjc c] = 0. -Proof. - rewrite char_poly_conj. - apply characteristic_polynomial_correct. -Qed. - -Lemma ctrace_eq (c1 c2 : R[i]) : - ctrace c1 = ctrace c2 <-> c1 + conjc c1 = c2 + conjc c2. -Proof. - rewrite /ctrace. - split; intros. - - apply /eqP. - rewrite eq_complex. - apply /andP. - split. - + now apply /eqP. - + by rewrite !ctrace_correct. - - by rewrite H. -Qed. - -Lemma cnorm_eq (c1 c2 : R[i]) : - cnorm c1 = cnorm c2 <-> c1 * conjc c1 = c2 * conjc c2. -Proof. - rewrite /cnorm. - split; intros. - - apply /eqP. - rewrite eq_complex. - apply /andP. - split. - + now apply /eqP. - + by rewrite !cnorm_correct. - - by rewrite H. -Qed. - -Lemma norm_trace_eq (c1 c2 : R[i]) : - ctrace c1 = ctrace c2 /\ cnorm c1 = cnorm c2 <-> - c1 = c2 \/ c1 = conjc c2. -Proof. - split; intros. - - destruct H. - destruct c1; destruct c2. - rewrite /ctrace /= in H. - rewrite /cnorm /= in H0. - assert (Re = Re0) by lra. - rewrite H1 in H0. - assert (Im *Im = Im0 * Im0) by lra. - rewrite /conjc. - rewrite H1. - assert (Im = Im0 \/ Im = -Im0). - { - rewrite -!expr2 in H2. - move /eqP in H2. - rewrite eqf_sqr in H2. - move /orP in H2. - destruct H2. - - by move /eqP in H2; left. - - by move /eqP in H2; right. - } - destruct H3. - + by rewrite H3; left. - + by rewrite H3; right. - - destruct H; rewrite H; try easy. - by rewrite trace_conj norm_conj conjcK. -Qed. - -Lemma charpoly_eq (c1 c2 : R[i]) : - characteristic_polynomial c1 = characteristic_polynomial c2 <-> - c1 = c2 \/ c1 = conjc c2. -Proof. - split; intros. - - rewrite /characteristic_polynomial in H. - apply polyP in H. - generalize (H 0%N); intros. - generalize (H 1%N); intros. - rewrite !coefD !coefZ !coefC !coefX !coefXn /= in H0. - rewrite !coefD !coefZ !coefC !coefX !coefXn /= in H1. - rewrite !mulr0 addr0 !add0r in H0. - rewrite !addr0 !add0r !mulr1 in H1. - apply norm_trace_eq. - split; trivial. - lra. - - destruct H. - + by rewrite H. - + by rewrite H char_poly_conj conjcK. -Qed. - -Lemma poly2_expand {S:comRingType} (c1 c2 : S) : - 'X^2 - (c1 + c2)*: 'X + (c1*c2)%:P = - ('X - c1%:P) * ('X - c2%:P). -Proof. - rewrite mulrDr !mulrDl addrA. - f_equal. - - rewrite scalerDl -expr2 -addrA. - f_equal. - rewrite (mulrC 'X _) opprD -!scaleNr -!mul_polyC. - by f_equal; rewrite polyCN. - - by rewrite mulrNN -scale_polyC mul_polyC. -Qed. - -Lemma RtoCR n : RtoC n%:R = n%:R. -Proof. - unfold RtoC. - induction n. - - now rewrite mulr0n. - - rewrite mulrSr /= mulrS -IHn. - apply /eqP. - rewrite eq_complex /= addrC addr0 /=. - by apply /andP. -Qed. - -Lemma charpoly_factor (c : R[i]) : - map_poly RtoC (characteristic_polynomial c) = - ('X - c%:P) * ('X - (conjc c)%:P). -Proof. - move: (size_charpoly c). - rewrite -poly2_expand /characteristic_polynomial /ctrace /cnorm. - rewrite map_polyE=>sz. - apply/polyP=>i. - rewrite coef_Poly. - case/orP: (leqVgt 3 i)=>ibd. - - rewrite !(coefD, coefZ, coefC, coefN, coefX, coefXn). - replace (i == 2%N) with false by lia. - replace (i == 1%N) with false by lia. - replace (i == 0%N) with false by lia. - rewrite /= mulr0 mulr0n oppr0 !addr0. - rewrite nth_default //. - by rewrite size_map sz. - - rewrite (nth_map 0); [| by rewrite sz]. - rewrite !(coefD, coefZ, coefC, coefN, coefX, coefXn). - rewrite !(rmorphD, rmorphM, rmorphN) /=. - rewrite !RtoCR RtoC_ctrace mulNr. - f_equal. - case : (eqVneq i 0%N). - + by rewrite RtoC_cnorm. - + by rewrite /RtoC. - Qed. - -Lemma charpoly_irreducible (c : R[i]) : - c != conjc c -> - irreducible_poly (characteristic_polynomial c). -Proof. - intros. - assert (1 < size (characteristic_polynomial c) <= 4). - { - rewrite size_charpoly //. - } - apply (cubic_irreducible H0). - intros. - apply /negP. - unfold not; intros. - assert (root (map_poly RtoC (characteristic_polynomial c)) (RtoC x)). - { - by apply rmorph_root. - } - rewrite charpoly_factor in H2. - unfold root in H2. - rewrite hornerM hornerD mulf_eq0 in H2. - move /orP in H2. - destruct H2. - - move /eqP in H2. - rewrite hornerX hornerN hornerC in H2. - apply (f_equal (fun z => z+ c)) in H2. - rewrite -addrA add0r (addrC _ c) addrN addr0 in H2. - by rewrite -H2 /RtoC /= eq_complex /= oppr0 !eq_refl /= in H. - - move /eqP in H2. - rewrite hornerD hornerX hornerN hornerC in H2. - apply (f_equal (fun z => z+ conjc c)) in H2. - rewrite -addrA add0r (addrC _ (conjc c)) addrN addr0 in H2. - apply (f_equal (fun z => conjc z)) in H2. - rewrite conjcK /RtoC /= in H2. - by rewrite -H2 /RtoC /= eq_complex /= opprK oppr0 !eq_refl /= in H. - Qed. - -Lemma charpoly_square (c : R[i]) : - c == conjc c -> - characteristic_polynomial c = ('X - (Re c)%:P)^+2. -Proof. - intros. - move /eqP in H. - rewrite expr2 /characteristic_polynomial /ctrace /cnorm H conjcK /= -H -poly2_expand. - f_equal. - - f_equal. - rewrite -scaleNr. - f_equal. - f_equal. - destruct c. - by simpl. - - f_equal. - destruct c. - simpl. - move /eqP in H. - rewrite eq_complex /= in H. - move /andP in H. - destruct H. - move /eqP in H0. - assert (Im = 0) by lra. - rewrite H1. - lra. -Qed. - -Lemma charpoly_coprime_case1 (c1 c2 : R[i]) : - c1 != conjc c1 -> - c2 != conjc c2 -> - characteristic_polynomial c1 != characteristic_polynomial c2 -> - coprimep (characteristic_polynomial c1) (characteristic_polynomial c2). -Proof. - intros. - apply /coprimepP. - intros. - apply charpoly_irreducible in H. - apply charpoly_irreducible in H0. - specialize (H d); intros. - specialize (H0 d); intros. - rewrite -size_poly_eq1. - case : (eqVneq (size d) 1%N); trivial. - intros. - specialize (H4 i H2). - specialize (H5 i H3). - rewrite eqp_sym in H4. - generalize (eqp_trans H4 H5); intros. - rewrite eqp_monic in H6. - - by rewrite H6 in H1. - - apply monic_charpoly. - - apply monic_charpoly. -Qed. - -Lemma charpoly_coprime_case2 (c1 c2 : R[i]) : - c1 != conjc c1 -> - c2 == conjc c2 -> - characteristic_polynomial c1 != characteristic_polynomial c2 -> - coprimep (characteristic_polynomial c1) (characteristic_polynomial c2). -Proof. - intros. - apply charpoly_square in H0. - apply /coprimepP. - intros. - apply charpoly_irreducible in H. - specialize (H d); intros. - rewrite -size_poly_eq1. - case : (eqVneq (size d) 1%N); trivial. - intros. - specialize (H4 i H2). - assert (size d == size (characteristic_polynomial c2)). - { - generalize (eqp_size H4); intros. - rewrite size_charpoly. - rewrite size_charpoly in H5. - by rewrite H5. - } - rewrite Pdiv.CommonIdomain.dvdp_size_eqp in H5; trivial. - rewrite eqp_sym in H4. - generalize (eqp_trans H4 H5); intros. - rewrite eqp_monic in H6. - - by rewrite H6 in H1. - - apply monic_charpoly. - - apply monic_charpoly. -Qed. - -Lemma charpoly_coprime_case3 (c1 c2 : R[i]) : - c1 == conjc c1 -> - c2 == conjc c2 -> - characteristic_polynomial c1 != characteristic_polynomial c2 -> - coprimep (characteristic_polynomial c1) (characteristic_polynomial c2). -Proof. - intros. - apply charpoly_square in H. - apply charpoly_square in H0. - rewrite H H0. - rewrite H H0 in H1. - apply coprimep_expl. - apply coprimep_expr. - apply coprimep_XsubC2. - apply /negP. - unfold not; intros. - rewrite subr_eq0 in H2. - move /eqP in H2. - rewrite H2 in H1. - move /negP in H1. - by apply H1. -Qed. - -Lemma charpoly_comprime (c1 c2 : R[i]) : - characteristic_polynomial c1 != characteristic_polynomial c2 -> - coprimep (characteristic_polynomial c1) (characteristic_polynomial c2). -Proof. - case : (boolP (c1 == conjc c1)); intros. - - case : (boolP (c2 == conjc c2)); intros. - + by apply charpoly_coprime_case3. - + rewrite coprimep_sym. - rewrite eq_sym in H. - by apply charpoly_coprime_case2. - - case : (boolP (c2 == conjc c2)); intros. - + by apply charpoly_coprime_case2. - + by apply charpoly_coprime_case1. -Qed. - - Lemma ev_C_1 : - forall (x : C), peval_C 1 x = 1. - Proof. - apply ev_C_is_rmorphism. - Qed. - - Definition peval_C_ker_pred (x : C) : pred {poly R} := - fun p => peval_C p x == 0. - - Lemma peval_C_ker_proper (x : C) : - proper_ideal (peval_C_ker_pred x). - Proof. - split. - - by rewrite /peval_C_ker_pred /in_mem /mem /= ev_C_1 oner_neq0. - - move => a b. - rewrite /in_mem /=. - rewrite /peval_C_ker_pred. - case: (ev_C_is_rmorphism x) => _ -> /eqP->. - by rewrite mulr0. - Qed. - - Lemma peval_C_ker_zmod (x : C) : - zmodPred (peval_C_ker_pred x). - Proof. - constructor. - - constructor; [constructor|]. - constructor. - + rewrite /in_mem //= /peval_C_ker_pred /peval_C. - unfold map_poly. - rewrite poly_size_0. - rewrite (eq_poly (fun _ => 0)). - * rewrite -{2}(horner0 x). - apply /eqP. - f_equal. - apply /polyP => i /=. - rewrite coef_poly coefC /=. - f_equal. - by case: (i == 0)%nat. - * move=> i ilt. - rewrite coefC. - by case: (i == 0)%nat. - + rewrite /in_mem //= /prop_in2 /peval_C_ker_pred => a b. - rewrite /in_mem /mem /= . - generalize (raddfD (ev_C_rmorphism x)); intros. - simpl in H; rewrite H. - revert H0 H1. - move => /eqP -> /eqP->. - rewrite add0r //. - - rewrite /Pred.Exports.oppr_closed /mem /= /peval_C_ker_pred => a. - rewrite /in_mem /= => /eqP-eqq1. - generalize (raddfN (ev_C_rmorphism x) a); intros. - simpl in H. - rewrite H eqq1 oppr0 //. - Qed. - - Definition peval_C_ker_is_ideal (x:C) : - idealr (peval_C_ker_pred x) - := MkIdeal (peval_C_ker_zmod x) (peval_C_ker_proper x). - - Canonical peval_C_ker_ideal (x:C) := KeyedPred (peval_C_ker_is_ideal x). - - Definition peval_C_ker_quot_ring (x:C) - := { ideal_quot (peval_C_ker_ideal x) }. - - Local Open Scope quotient_scope. - - Definition peval_C_quot (x:C) : peval_C_ker_quot_ring x -> C - := lift_fun1 (peval_C_ker_quot_ring x) (fun p => peval_C p x). - - Lemma pi_peval_C_quot x : {mono (\pi_(peval_C_ker_quot_ring x)) : p / peval_C p x >-> peval_C_quot x p}. - Proof. - move=> p. - rewrite /peval_C_quot -eq_lock. - case: piP => a /EquivQuot.eqmodP. - rewrite /Quotient.equiv_equiv /Quotient.equiv /in_mem /mem /= /peval_C_ker_pred. - destruct (ev_C_is_rmorphism x). - rewrite base => eqq. - move=> /eqP in eqq. - apply (f_equal (fun z => z + peval_C a x)) in eqq. - by rewrite -addrA add0r (addrC _ (peval_C a x)) addrN addr0 in eqq. - Qed. - - Lemma peval_C_quot1 x : peval_C_quot x 1 = 1. - Proof. - rewrite /one /= /Quotient.one /= /one /= /locked. - destruct master_key. - rewrite pi_peval_C_quot /peval_C. - rewrite /map_poly size_poly1. - rewrite horner_poly big_ord1 expr0 mulr1. - by rewrite -horner_coef0 hornerC. - Qed. - - Lemma peval_quot_C_is_rmorphism (c:C): rmorphism (peval_C_quot c). - Proof. - split => [x|]. - - apply quotP=> y <-. - revert x. - apply quotP => x <-. - rewrite !reprK. - rewrite !pi_peval_C_quot. - rewrite /peval_C_quot -!eq_lock. - rewrite -pi_is_additive. - case: piP => y' /eqquotP. - rewrite /peval_C_ker_pred /in_mem/mem/= => /eqP. - generalize (raddfD (ev_C_rmorphism c)); intros peval_C_add. - generalize (raddfN (ev_C_rmorphism c)); intros peval_C_opp. - generalize (peval_C_add (x-y) (-y')); intros add1. - simpl in add1; rewrite add1. - specialize (peval_C_add x (-y)); simpl in peval_C_add. - rewrite peval_C_add. - generalize (peval_C_opp y); intro opp1. - simpl in opp1; rewrite opp1. - specialize (peval_C_opp y'); simpl in peval_C_opp. - rewrite peval_C_opp. - intro HH. - apply (f_equal (fun z => z + peval_C y' c)) in HH. - rewrite add0r -!addrA in HH. - rewrite (addrC _ (peval_C y' c)) addrN addr0 in HH. - by rewrite -HH. - - constructor. - + move => x. - apply quotP=> y <-. - revert x. - apply quotP => x <-. - rewrite !reprK. - rewrite !pi_peval_C_quot. - rewrite /peval_C_quot -!eq_lock. - rewrite -pi_is_multiplicative. - case: piP => y' /eqquotP. - rewrite /peval_C_ker_pred /in_mem/mem/= => /eqP. - destruct (ev_C_is_rmorphism c) as [? [??]]. - specialize (base (x * y) y'); simpl in base. - rewrite base. - specialize (m x y); simpl in m. - rewrite m. - intro HH. - apply (f_equal (fun z => z + peval_C y' c)) in HH. - rewrite add0r -!addrA in HH. - rewrite (addrC _ (peval_C y' c)) addrN addr0 in HH. - by rewrite -HH. - + by apply peval_C_quot1. - Qed. - - Lemma peval_C_quot_is_injective (c:C) : - injective (peval_C_quot c). - Proof. - intros x y. - rewrite /peval_C_quot -!eq_lock. - rewrite -{2}[x]reprK -{2}[y]reprK. - move: (repr x) (repr y) => {x} {y} x y eqq. - apply/eqquotP. - rewrite /Quotient.equiv/=. - rewrite /peval_C_ker_pred /in_mem/mem/=. - apply/eqP. - destruct (ev_C_is_rmorphism c). - specialize (base x y). - simpl in base. - by rewrite eqq addrN in base. - Qed. - - Lemma cval_decomp (c1 c2 : C) : - Im c1 != 0 -> - {a : R * R | (a.1%:C * c1 + a.2%:C)%C = c2}. - Proof. - intros. - exists ((Im c2/Im c1), (Re c2 - (Im c2 / Im c1)*Re c1)). - apply /eqP. - rewrite eq_complex. - apply /andP. - destruct c1. - destruct c2. - simpl. - rewrite !mul0r !addr0 subr0. - - split. - + apply /eqP. - rewrite addrC. - generalize (subrKA (Im0 / Im * Re) Re0 0); intros. - by rewrite !addr0 in H0. - + rewrite -mulrA. - rewrite (mulrC _ Im). - rewrite divff //. - by rewrite mulr1. - Qed. - - Lemma peval_C_decomp (c1 c2 : C) : - Im c1 != 0 -> - {p : {poly R} | peval_C p c1 = c2}. - Proof. - intros. - destruct (cval_decomp c1 c2 H). - exists (x.1 *: 'X + x.2 %:P). - rewrite /peval_C -e. - case: (eqVneq x.1 0) => [-> |]. - - rewrite scale0r add0r /map_poly. - rewrite horner_poly mul0r add0r. - case: (eqVneq x.2 0) => [-> |]. - + by rewrite size_poly0 big_ord0. - + intros. - rewrite size_polyC. - rewrite i // big_ord1 expr0 mulr1 coefC //. - - intros. - rewrite /map_poly size_addl size_scale //; rewrite size_polyX. - + rewrite horner_poly. - rewrite big_ord_recl big_ord1 /= expr0 mulr1 /bump /= addn0 expr1. - rewrite !coefD !coefZ !coefX !coefC /=. - rewrite mulr0 mulr1 addr0 add0r. - by rewrite addrC mulrC. - + rewrite size_polyC. - by case: (eqVneq x.2 0). - Qed. - - Lemma peval_C_quot_is_surjective (c c2 :C) : - Im c != 0 -> - {p: peval_C_ker_quot_ring c | peval_C_quot c p = c2}. - Proof. - intros. - destruct (peval_C_decomp c c2); trivial. - exists (\pi_(peval_C_ker_quot_ring c) x). - rewrite -e. - apply pi_peval_C_quot. - Qed. - - - Lemma peval_C_quot_is_bijective (c :C) : - Im c != 0 -> - bijective (peval_C_quot c). - Proof. - intros imn0. - pose g : R[i] -> peval_C_ker_quot_ring c := - fun c2 => sval (peval_C_quot_is_surjective c c2 imn0). - apply Bijective with (g := g). - - intros ?. - assert (peval_C_quot c (g (peval_C_quot c x)) = peval_C_quot c x). - { - rewrite /g. - destruct (peval_C_quot_is_surjective c (peval_C_quot c x) imn0). - by simpl. - } - by apply peval_C_quot_is_injective in H. - - intros ?. - rewrite /g. - destruct (peval_C_quot_is_surjective c x imn0). - by simpl. - Qed. - - -End rmorphism. - -Module matrix_ring. -Section matrixRing. - Section base_ring. - Context {T:ringType}. - - Variable (n m : nat). - - Definition MR_mul (A B : 'M[T]_(S n, S m)) := map2_mx (fun (a b : T) => a * b) A B. - Definition MR1 : 'M[T]_(S n,S m) := const_mx 1. - Definition MR0: 'M[T]_(S n,S m) := const_mx 0. - Lemma MR_mulA : associative MR_mul. - Proof. - by move=> A B C; apply/matrixP=> i j; rewrite !mxE Monoid.mulmA. - Qed. - - Lemma MR_mul1z : left_id MR1 MR_mul. - Proof. - intros ?. - unfold MR_mul, MR1. - unfold map2_mx, const_mx. - apply matrixP; intros ??. - rewrite !mxE mul1r //. - Qed. - - Lemma MR_mulz1 : right_id MR1 MR_mul. - Proof. - intros ?. - unfold MR_mul, MR1. - unfold map2_mx, const_mx. - apply matrixP; intros ??. - rewrite !mxE mulr1 //. - Qed. - - Fact mul_0MR : left_zero MR0 MR_mul. - Proof. - intros ?. - unfold MR_mul, MR1. - unfold map2_mx, const_mx. - apply matrixP; intros ??. - rewrite !mxE mul0r //. - Qed. - - Fact mul_MR0 : right_zero MR0 MR_mul. - Proof. - intros ?. - unfold MR_mul, MR1. - unfold map2_mx, const_mx. - apply matrixP; intros ??. - rewrite !mxE mulr0 //. - Qed. - - Lemma MR_mul_addr : right_distributive MR_mul (@addmx T (S n) (S m)). - Proof. - move=> A B C; apply/matrixP=> i j. - unfold MR_mul, addmx, map2_mx. - rewrite !mxE mulrDr //. - Qed. - - Lemma MR_mul_addl : left_distributive MR_mul (@addmx T (S n) (S m)). - Proof. - move=> A B C; apply/matrixP=> i j. - unfold MR_mul, addmx, map2_mx. - rewrite !mxE mulrDl //. - Qed. - - Fact MR1_neq0 : MR1 != MR0. - Proof. - unfold MR1, MR0, const_mx. - apply /eqP/matrixP. - move/(_ ord0 ord0). - rewrite !mxE. - apply/eqP/oner_neq0. - Qed. - - Definition MR_ringMixin := - RingMixin MR_mulA MR_mul1z MR_mulz1 MR_mul_addl MR_mul_addr MR1_neq0. - - Canonical MR_ringType := Eval hnf in RingType 'M[T]_(S n,S m) MR_ringMixin. - End base_ring. - - Section com_ring. - - Context {T:comRingType}. - Variable (n m : nat). - - Lemma MR_mulC : commutative (@MR_mul T n m). - Proof. - intros ??. - unfold MR_mul, map2_mx. - apply eq_mx; intros ??. - rewrite mulrC //. - Qed. - - Canonical MR_comRingType := Eval hnf in ComRingType (@MR_ringType T n m) MR_mulC. - -(* Definition MR_embed (val : 'M[T]_(S n,S m)) : MR_comRingType := val. *) - - End com_ring. - - Section unit_ring. - Context {T:unitRingType}. - Variable (n m : nat). - - Definition MR_unit : pred 'M[T]_(S n,S m) - := [pred m : 'M[T]_(S n,S m) | [forall i, [forall j, (m i j) \is a unit]]]. - - Definition MR_inv (x:'M[T]_(S n,S m)) : 'M[T]_(S n,S m) - := if MR_unit x then map_mx inv x else x. - - Lemma MR_mulVr : ({in MR_unit, left_inverse 1 MR_inv (@mul (MR_ringType _ _))}). - Proof. - move=> M. - rewrite /MR_inv. - rewrite -topredE /= => eqq1. - rewrite eqq1. - apply/matrixP => x y. - rewrite !mxE. - apply mulVr. - move: eqq1. - rewrite/MR_unit simpl_predE. - by move/forallP /(_ x)/forallP/(_ y). - Qed. - - Lemma MR_mulrV : ({in MR_unit, right_inverse 1 MR_inv (@mul (MR_ringType _ _))}). - Proof. - move=> M. - rewrite /MR_inv. - rewrite -topredE /= => eqq1. - rewrite eqq1. - apply/matrixP => x y. - rewrite !mxE. - apply mulrV. - move: eqq1. - rewrite/MR_unit simpl_predE. - by move/forallP /(_ x)/forallP/(_ y). - Qed. - - Lemma MR_unitP (x y : MR_ringType n m) : y * x = 1 /\ x * y = 1 -> MR_unit x. - Proof. - move=>[/matrixP-linv /matrixP-rinv]. - rewrite/MR_unit/simpl_predE. - apply/forallP => i. - apply/forallP => j. - apply/unitrP. - exists (y i j). - move: linv => /(_ i j). - move: rinv => /(_ i j). - by rewrite !mxE => -> ->. - Qed. - - Lemma MR_inv0id : {in [predC MR_unit], MR_inv =1 id}. - Proof. - move => x. - by rewrite -topredE /= -topredE /= /MR_inv => /Bool.negb_true_iff->. - Qed. - - Definition MR_unitRingMixin : UnitRing.mixin_of (Ring.Pack (Ring.class (MR_ringType n m))) - := @UnitRingMixin (MR_ringType _ _) MR_unit MR_inv MR_mulVr MR_mulrV MR_unitP MR_inv0id. - - Canonical MR_unitRingType := Eval hnf in UnitRingType (@MR_ringType T n m) MR_unitRingMixin. - - End unit_ring. - - End matrixRing. -End matrix_ring. - -Section eval_vectors. - - Import matrix_ring. - - Context {n} (vals : 'rV[R[i]]_(n.+1)). - - Definition mx_eval (p : {poly R}) : MR_comRingType 0 n := - (map_mx (fun x => (map_poly RtoC p).[x]) vals). - - Lemma mx_evalC c : mx_eval c%:P = const_mx (RtoC c). - Proof. - apply matrixP=> a b. - rewrite !mxE. - by rewrite (map_polyC RtoC_rmorphism) /= hornerC. - Qed. - - Lemma mx_eval_is_rmorphism : - rmorphism (fun p => mx_eval p). - Proof. - constructor. - - move=> x y. - apply matrixP=> a b. - rewrite !mxE. - by apply ev_C_is_rmorphism. - - split. - + move=> x y. - apply matrixP=> a b. - rewrite !mxE. - by apply ev_C_is_rmorphism. - + by rewrite mx_evalC. - Qed. - - Canonical mx_eval_rmorphism - : {rmorphism poly_ringType R_ringType -> MR_comRingType 0 n} - := RMorphism (mx_eval_is_rmorphism). - - Lemma mx_eval_1 : - mx_eval 1 = 1. - Proof. - apply mx_eval_is_rmorphism. - Qed. - - Definition mx_eval_ker_pred : pred {poly R} := - fun p => mx_eval p == 0. - - Lemma mx_eval_ker_proper : - proper_ideal (mx_eval_ker_pred). - Proof. - split. - - by rewrite /mx_eval_ker_pred /in_mem /mem /= mx_eval_1 oner_neq0. - - move => a b. - rewrite /in_mem /=. - rewrite /mx_eval_ker_pred. - case: (mx_eval_is_rmorphism) => _ -> /eqP->. - by rewrite mulr0. - Qed. - - Lemma mx_eval_ker_zmod : - zmodPred (mx_eval_ker_pred). - Proof. - constructor. - - constructor; [constructor|]. - constructor. - + rewrite /in_mem //= /mx_eval_ker_pred /mx_eval /map_mx. - apply/eqP/matrixP=>a b. - rewrite !mxE. - unfold map_poly. - generalize (hornerC (RtoC 0) (vals a b)); intros. - rewrite poly_size_0. - rewrite (eq_poly (fun _ => 0)). - * rewrite -{2}(horner0 (vals a b)). - f_equal. - apply /polyP => i /=. - rewrite coef_poly coefC /=. - by case: (i == 0)%nat. - * move=> i ilt. - rewrite coefC. - by case: (i == 0)%nat. - + rewrite /in_mem //= /prop_in2 /mx_eval_ker_pred => a b. - rewrite /in_mem /mem /= . - generalize (raddfD (mx_eval_rmorphism) a b); intros. - simpl in H; rewrite H. - revert H0 H1. - move => /eqP -> /eqP->. - rewrite add0r //. - - rewrite /Pred.Exports.oppr_closed /mem /= /mx_eval_ker_pred => a. - rewrite /in_mem /= => /eqP-eqq1. - generalize (raddfN (mx_eval_rmorphism) a); intros. - simpl in H. - rewrite H eqq1 oppr0 //. - Qed. - - Definition mx_eval_ker_is_ideal : - idealr mx_eval_ker_pred - := MkIdeal mx_eval_ker_zmod mx_eval_ker_proper. - - Canonical mx_eval_ker_ideal := KeyedPred (mx_eval_ker_is_ideal). - - Definition mx_eval_ker_quot_ring - := { ideal_quot mx_eval_ker_ideal }. - - Local Open Scope quotient_scope. - - Definition mx_eval_quot : mx_eval_ker_quot_ring -> MR_comRingType 0 n - := lift_fun1 mx_eval_ker_quot_ring mx_eval. - - Lemma pi_mx_eval_quot : {mono (\pi_mx_eval_ker_quot_ring) : x / mx_eval x >-> mx_eval_quot x}. - Proof. - move=> x. - rewrite /mx_eval_quot -eq_lock. - case: piP => a /EquivQuot.eqmodP. - rewrite /Quotient.equiv_equiv /Quotient.equiv /in_mem /mem /= /mx_eval_ker_pred. - destruct mx_eval_is_rmorphism. - rewrite base => eqq. - move=> /eqP in eqq. - apply (f_equal (fun z => z + mx_eval a)) in eqq. - by rewrite -addrA add0r (addrC _ (mx_eval a)) addrN addr0 in eqq. - Qed. - - Lemma mx_eval_quotC c : - mx_eval_quot (\pi_({ideal_quot mx_eval_ker_ideal}) c%:P) = const_mx (RtoC c). - Proof. - by rewrite pi_mx_eval_quot mx_evalC. - Qed. - - Lemma mx_eval_quot1 : mx_eval_quot 1 = 1. - Proof. - rewrite /one /= /Quotient.one /= /one /= /locked. - destruct master_key. - by rewrite mx_eval_quotC. - Qed. - - Lemma mx_eval_quot_is_rmorphism : rmorphism mx_eval_quot. - Proof. - split => [x|]. - - apply quotP=> y <-. - revert x. - apply quotP => x <-. - rewrite !reprK. - rewrite !pi_mx_eval_quot. - rewrite /mx_eval_quot -!eq_lock. - rewrite -pi_is_additive. - case: piP => y' /eqquotP. - rewrite /mx_eval_ker_pred /in_mem/mem/= => /eqP. - generalize (raddfD (mx_eval_rmorphism)); intros mx_eval_add. - generalize (raddfN (mx_eval_rmorphism)); intros mx_eval_opp. - generalize (mx_eval_add (x-y) (-y')); intros add1. - simpl in add1; rewrite add1. - specialize (mx_eval_add x (-y)); simpl in mx_eval_add. - rewrite mx_eval_add. - generalize (mx_eval_opp y); intro opp1. - simpl in opp1; rewrite opp1. - specialize (mx_eval_opp y'); simpl in mx_eval_opp. - rewrite mx_eval_opp. - intro HH. - apply (f_equal (fun z => z + mx_eval y')) in HH. - rewrite add0r -!addrA in HH. - rewrite (addrC _ (mx_eval y')) addrN addr0 in HH. - by rewrite -HH. - - constructor. - + move => x. - apply quotP=> y <-. - revert x. - apply quotP => x <-. - rewrite !reprK. - rewrite !pi_mx_eval_quot. - rewrite /mx_eval_quot -!eq_lock. - rewrite -pi_is_multiplicative. - case: piP => y' /eqquotP. - rewrite /mx_eval_ker_pred /in_mem/mem/= => /eqP. - destruct mx_eval_is_rmorphism as [? [??]]. - specialize (base (x * y) y'); simpl in base. - rewrite base. - specialize (m x y); simpl in m. - rewrite m. - intro HH. - apply (f_equal (fun z => z + mx_eval y')) in HH. - rewrite add0r -!addrA in HH. - rewrite (addrC _ (mx_eval y')) addrN addr0 in HH. - by rewrite -HH. - + by apply mx_eval_quot1. - Qed. - - Lemma mx_eval_quot_is_injective : - injective mx_eval_quot. - Proof. - intros x y. - rewrite /mx_eval_quot -!eq_lock. - rewrite -{2}[x]reprK -{2}[y]reprK. - move: (repr x) (repr y) => {x} {y} x y eqq. - apply/eqquotP. - rewrite /Quotient.equiv/=. - rewrite /mx_eval_ker_pred /in_mem/mem/=. - apply/eqP. - destruct mx_eval_is_rmorphism. - specialize (base x y). - simpl in base. - by rewrite eqq addrN in base. - Qed. - - Lemma poly_eval_mod (a b p : {poly R}) (c : R[i]) : - a %% p = b %% p -> - (map_poly RtoC p).[c] = 0 -> - (map_poly RtoC a).[c] = (map_poly RtoC b).[c]. - Proof. - intros. - rewrite (divp_eq a p). - rewrite (divp_eq b p) H. - by rewrite !rmorphD !rmorphM !hornerD !hornerM H0 !mulr0 !add0r. - Qed. - - Lemma mx_eval_is_surjective : - let charvals := [seq characteristic_polynomial (vals 0 j) | j : 'I_n.+1] in - pairwise (coprimep (R := R_fieldType)) charvals -> - (forall j, Im (vals 0 j) != 0) -> - forall (c : MR_comRingType 0 n), - {x : {poly R} | mx_eval x = c}. - Proof. - intros charvals cop imn0 c. - - pose rvals := [seq sval (peval_C_decomp (vals 0 j) (c 0 j) (imn0 j)) | j : 'I_n.+1]. - have rvals_prop: forall (j:'I_n.+1), peval_C (rvals`_j) (vals 0 j) = (c 0 j). - { - subst rvals=>j. - rewrite (nth_map 0)/=. - - move: (svalP (peval_C_decomp (vals 0 (enum 'I_n.+1)`_j) (c 0 (enum 'I_n.+1)`_j) (imn0 (enum 'I_n.+1)`_j))). - by rewrite !nth_ord_enum. - - rewrite size_enum_ord. - by destruct j. - } - - have eqsize: size rvals = size charvals. - { - subst rvals charvals. - by rewrite !size_map -enumT size_enum_ord. - } - - generalize (chinesep_list_prop R_fieldType (zip rvals charvals)); intros. - assert ([seq i.2 | i <- zip rvals charvals] = charvals). - { - apply (@eq_from_nth _ 0). - - by rewrite size_map size_zip eqsize ?minnn. - - intros. - rewrite size_map in H. - rewrite (nth_map 0) //. - by rewrite nth_zip_cond H. - } - rewrite -H in cop. - specialize (X cop). - destruct X. - exists x. - apply /matrixP. - intros ??. - pose p := (zip rvals charvals)`_y. - assert (p \in zip rvals charvals). - { - subst p. - apply/(nthP 0). - subst rvals charvals. - exists y => //. - by rewrite size_zip eqsize minnn size_map size_enum_ord ltn_ord. - } - specialize (e p H0). - rewrite ord1 -(rvals_prop y). - rewrite /mx_eval /peval_C. - rewrite !mxE. - have eqq1: charvals`_y = p.2. - { - subst p. - by rewrite nth_zip. - } - have eqq2: rvals`_y = p.1. - { - subst p. - by rewrite nth_zip. - } - rewrite eqq2. - apply (poly_eval_mod _ _ _ _ e). - rewrite -eqq1. - assert (charvals`_y = characteristic_polynomial (vals 0 y)). - { - unfold charvals. - rewrite (nth_map 0)/=. - - by rewrite nth_ord_enum. - - by rewrite size_enum_ord ltn_ord. - } - rewrite H1. - apply characteristic_polynomial_correct. - Qed. - - Lemma mx_eval_quot_is_surjective : - let charvals := [seq characteristic_polynomial (vals 0 j) | j : 'I_n.+1] in - pairwise (coprimep (R := R_fieldType)) charvals -> - (forall j, Im (vals 0 j) != 0) -> - forall (c : MR_comRingType 0 n), - {x : mx_eval_ker_quot_ring | mx_eval_quot x = c}. - Proof. - intros. - destruct (mx_eval_is_surjective H H0 c). - exists (\pi_mx_eval_ker_quot_ring x). - by rewrite pi_mx_eval_quot. - Qed. - - Lemma mx_eval_quot_is_bijective : - let charvals := [seq characteristic_polynomial (vals 0 j) | j : 'I_n.+1] in - pairwise (coprimep (R := R_fieldType)) charvals -> - (forall j, Im (vals 0 j) != 0) -> - bijective mx_eval_quot. - Proof. - intros charvals cop imn0. - pose g : MR_comRingType 0 n -> mx_eval_ker_quot_ring := - fun c => sval (mx_eval_quot_is_surjective cop imn0 c). - apply Bijective with (g := g). - - intros ?. - assert (mx_eval_quot (g (mx_eval_quot x)) = mx_eval_quot x). - { - rewrite /g. - destruct (mx_eval_quot_is_surjective cop imn0 (mx_eval_quot x)). - by simpl. - } - by apply mx_eval_quot_is_injective in H. - - intros ?. - rewrite /g. - destruct (mx_eval_quot_is_surjective cop imn0 x). - by simpl. - Qed. - - -End eval_vectors. - -Definition odd_nth_roots' n := - \row_(j < (S (proj1_sig (pow2_S n)))) - (nth_root (2 * j + 1) (2 ^ (S n))). - -Lemma odd_nth_roots_minpoly n : - forall i, - root ('X^(2^n) + 1%:P) (odd_nth_roots n 0 i). -Proof. - move=> i. - rewrite /odd_nth_roots mxE /root hornerD hornerXn hornerC. - generalize (odd_roots_prim i n); intros. - apply (f_equal (fun z => C1 + z)) in H. - rewrite addrN in H. - rewrite addrC in H. - apply /eqP. - unfold zero in *; simpl in *. - unfold real_complex_def in *. - unfold zero in *; simpl in *. - by rewrite <- H. -Qed. - -Lemma odd_nth_roots_minpoly_complex n : - forall (c : R[i]), - root ('X^(2^(S n)) + 1%:P) c -> - Im c <> 0. -Proof. - intros. - unfold root in H. - rewrite hornerD hornerXn hornerC in H. - move /eqP in H. - unfold not; intros. - destruct c. - simpl in H0. - rewrite H0 in H. - assert (forall x:R, x^+2 + 1 <> 0). - { - intros. - apply Rgt_not_eq. - generalize (pow2_ge_0 x); intros. - rewrite /one /add /zero /= -RpowE. - coq_lra. - } - assert (forall x:R, x^+(2^(S n)) + 1 <> 0). - { - intros. - by rewrite expnS mulnC exprM. - } - clear H0 H1. - replace ((Re +i* 0)%C ^+ (2 ^ n.+1)) with (RtoC (Re ^+ (2 ^ n.+1))) in H. - - unfold RtoC in H. - move /eqP in H. - rewrite eq_complex in H. - move /andP in H. - destruct H. - simpl in H. - move /eqP in H. - by specialize (H2 Re). - - assert (forall n, - RtoC (Re ^+ n) = (Re +i* 0)%C ^+n). - { - induction n0. - - by rewrite !expr0. - - rewrite !exprS. - assert (forall (x y : R), - RtoC (x * y) = RtoC x * RtoC y). - { - apply RtoC_is_rmorphism. - } - by rewrite H0 IHn0. - } - apply H0. - Qed. - -Lemma minpoly_mult_odd_nth_roots n (p : {poly R[i]}) : - Pdiv.Ring.rmodp p ('X^(2^n) + 1%:P) = 0 -> - forall i, root p (odd_nth_roots n 0 i). -Proof. -intros. -move=> /Pdiv.Ring.rmodp_eq0P in H. -rewrite Pdiv.ComRing.rdvdp_eq in H. -destruct (pow2_S n). -generalize (Xn_add_c_monic (R:=ComplexField.complex_comRingType R_fieldType) x 1); intros. -rewrite monicE in H0. -move=> /eqP in H0. -rewrite -(eqP i0) in H0. -rewrite H0 in H. -rewrite Theory.expr1n in H. -move=> /eqP in H. -rewrite Theory.scale1r in H. -rewrite H. -unfold root. -apply /eqP. -rewrite hornerM. -generalize (odd_nth_roots_minpoly n i); intros. -unfold root in H1. -move=> /eqP in H1. -rewrite H1. -rewrite mulr0 //. -Qed. - - -Lemma drop_poly_opp [S : ringType] n (p : {poly S}) : - drop_poly n (- p) = - drop_poly n p. -Proof. - rewrite -!scaleN1r. - apply drop_polyZ. -Qed. - -Lemma drop_poly_diff [S : ringType] n (p q : {poly S}) : - drop_poly n p = drop_poly n q -> - seq.size (p - q) <= n. -Proof. - intros. - assert (drop_poly n (p - q) = 0). - { - by rewrite drop_polyD drop_poly_opp -H addrN. - } - generalize (poly_take_drop n (p - q)); intro decomp. - rewrite H0 mul0r addr0 in decomp. - rewrite -decomp. - apply size_take_poly. -Qed. - -Lemma monic_size_pos [S : ringType] (p : {poly S}) : - p \is monic -> - seq.size p > 0. -Proof. - case: ltP => // n0lt. - have: seq.size p == 0%nat by lia. - rewrite size_poly_eq0 => eqq1 /monic_neq0. - by rewrite eqq1. -Qed. - -Lemma monic_drop_n_1 [S : ringType] n (p : {poly S}) : - p \is monic -> - seq.size p = n.+1 -> - drop_poly n p = 1. -Proof. - rewrite monicE /lead_coef. - rewrite /drop_poly poly_def => /[swap] -> /=. - replace (n.+1 - n)%nat with 1%nat by lia. - rewrite big_ord1 add0n => /eqP->. - by rewrite expr0 alg_polyC. -Qed. - -Lemma monic_dif_same_deg (p q : {poly R[i]}) : - p \is monic -> - q \is monic -> - seq.size p = seq.size q -> - seq.size (q - p) < seq.size p. -Proof. - intros. - pose (n := seq.size p). - pose (n1 := (n-1)%nat). - assert (n = S n1). - { - generalize (monic_size_pos p H); lia. - } - unfold n in H2. - rewrite H2. - generalize (drop_poly_diff n1 q p); intros. - apply H3. - rewrite monic_drop_n_1; trivial. - - rewrite monic_drop_n_1; trivial. - - rewrite -H1 H2 //. - Qed. - -Lemma monic_divides_same_deg (p q : {poly R[i]}) : - p \is monic -> - q \is monic -> - seq.size p = seq.size q -> - Pdiv.Ring.rdvdp (R:=ComplexField.complex_unitRingType R_realFieldType) - p q -> - p = q. -Proof. - intros. - generalize (Pdiv.RingMonic.rdivp_eq H q); intros. - generalize (monic_dif_same_deg _ _ H H0 H1); intros. - generalize (Pdiv.RingMonic.redivp_eq H 1 H4); intros. - rewrite mul1r (addrC q _) addrA addrN add0r in H5. - assert (Pdiv.CommonRing.rdivp (R:=ComplexField.complex_ringType R_fieldType) q p = 1). - { - unfold Pdiv.CommonRing.rdivp. - by rewrite H5 /=. - } - rewrite H6 mul1r in H3. - unfold Pdiv.Ring.rdvdp in H2. - move=> /eqP in H2. - rewrite H2 addr0 in H3. - by symmetry. -Qed. - -Lemma seq_all_odd_roots n (p : {poly R[i]}) : - let rs := MatrixFormula.seq_of_rV (odd_nth_roots n) in - (forall i, root p (odd_nth_roots n 0 i)) -> - seq.all (root p) (MatrixFormula.seq_of_rV (odd_nth_roots n)). -Proof. - intros. - apply/tuple.all_tnthP => /= i. - rewrite finfun.tnth_fgraph /= finfun.ffunE. - by apply H. -Qed. - -Lemma odd_roots_uniq' n : - forall i j, - odd_nth_roots n 0 i = odd_nth_roots n 0 j -> - i = j. -Proof. - rewrite /odd_nth_roots => i j. - rewrite !mxE. - destruct (pow2_S (S n)) as [? eqq1]. - rewrite (eqP eqq1) => /nth_root_eq. - rewrite -(eqP eqq1). - rewrite !Nat.mod_small /=. - - move/addIn => eqq2. - apply/ord_inj. - lia. - - rewrite expnS. - destruct i; destruct j; simpl in *. - lia. - - rewrite expnS. - destruct i; destruct j; simpl in *. - lia. -Qed. - -Lemma odd_roots_uniq n : - let rs := MatrixFormula.seq_of_rV (odd_nth_roots n) in - uniq_roots (MatrixFormula.seq_of_rV (odd_nth_roots n)). -Proof. - rewrite uniq_rootsE /= /odd_nth_roots. - rewrite /MatrixFormula.seq_of_rV. - apply /tuple.tuple_uniqP => i j. - rewrite !finfun.tnth_fgraph !finfun.ffunE => eqq1. - apply odd_roots_uniq' in eqq1. - by apply enum_val_inj. -Qed. - -Lemma odd_nth_roots_minpoly_mult n (p : {poly R[i]}) : - (forall i, root p (odd_nth_roots n 0 i)) -> - Pdiv.Ring.rmodp p ('X^(2^n) + 1%:P) = 0 . -Proof. - intros roots. - move: (seq_all_odd_roots n p roots) (odd_roots_uniq n) => allroot uniqroot. - generalize (Pdiv.UnitRing.uniq_roots_rdvdp allroot uniqroot); intros. - pose (rs := MatrixFormula.seq_of_rV (odd_nth_roots n)). - assert (\prod_(z <- rs) ('X - z%:P) = 'X^(2 ^ n) + 1%:P). - { - apply monic_divides_same_deg. - - apply monic_prod_XsubC. - - destruct (pow2_S n). - rewrite (eqP i). - apply Xn_add_c_monic. - - rewrite size_prod_XsubC. - destruct (pow2_S n). - rewrite (eqP i). - rewrite size_Xn_addC. - rewrite -(eqP i). - by rewrite MatrixFormula.size_seq_of_rV. - - generalize (seq_all_odd_roots n _ (odd_nth_roots_minpoly n)); intros allroot2. - apply (Pdiv.UnitRing.uniq_roots_rdvdp allroot2 uniqroot). - } - rewrite H0 in H. - unfold Pdiv.Ring.rdvdp in H. - by move=> /eqP in H. - Qed. - -Lemma map_RtoC_size (p : {poly R}) : - seq.size p = seq.size (map_poly RtoC p). -Proof. - by rewrite (size_map_poly RtoC_rmorphism p). -Qed. - -Lemma map_RtoC_lead_coef (p : {poly R}) : - lead_coef (map_poly RtoC p) = RtoC (lead_coef p). -Proof. - by rewrite (lead_coef_map RtoC_rmorphism p). -Qed. - -Lemma rmodp_RtoC_morph : {morph (map_poly RtoC) : p q / Pdiv.Ring.rmodp p q }. -Proof. - rewrite /Pdiv.Ring.rmodp => p q. - by rewrite redivp_map. -Qed. - -Lemma RtoC_inj : injective RtoC. -Proof. - rewrite /RtoC=>a b. - by move=> []. -Qed. - -Lemma map_poly_RtoC_eq0E p : map_poly RtoC p = 0 <-> (p = 0). -Proof. - split. - - rewrite -(map_poly0 RtoC). - move/map_inj_poly. - apply => //. - by apply RtoC_inj. - - move=> ->. - by rewrite map_poly0. -Qed. - -Lemma rmodp_R (p q : {poly R}) : - Pdiv.Ring.rmodp p q = 0 <-> Pdiv.Ring.rmodp (map_poly RtoC p) (map_poly RtoC q) = 0. -Proof. - rewrite -rmodp_RtoC_morph. - symmetry. - apply map_poly_RtoC_eq0E. -Qed. - -Lemma map_poly_add_RtoC (p q : {poly R}) : - map_poly RtoC (p + q) = (map_poly RtoC p) + (map_poly RtoC q). -Proof. - apply (raddfD map_RtoC_rmorphism). -Qed. - -Lemma map_RtoC_Xnpoly n : - map_poly (aR:=R_ringType) (rR:=ComplexField.complex_ringType R_fieldType) RtoC - ('X^(2 ^ n) + 1%:P) = 'X^(2 ^ n) + 1%:P. -Proof. - rewrite map_poly_add_RtoC map_polyXn. - f_equal. - by rewrite (map_polyC RtoC_rmorphism 1). -Qed. - -Lemma odd_nth_roots_minpoly_mult_R n (p : {poly R}) : - (forall i, root (map_poly RtoC p) (odd_nth_roots n 0 i)) -> - Pdiv.Ring.rmodp p ('X^(2^n) + 1%:P) = 0. -Proof. - intros. - generalize (odd_nth_roots_minpoly_mult n (map_poly RtoC p) H); intros. - rewrite rmodp_R. - rewrite <- H0. - f_equal. - apply map_RtoC_Xnpoly. -Qed. - -Lemma minpoly_mult_odd_nth_roots_R n (p : {poly R}) : - Pdiv.Ring.rmodp p ('X^(2^n) + 1%:P) = 0 -> - forall i, root (map_poly RtoC p) (odd_nth_roots n 0 i). -Proof. - intros. - rewrite rmodp_R in H. - rewrite map_RtoC_Xnpoly in H. - by generalize (minpoly_mult_odd_nth_roots n (map_poly RtoC p) H). -Qed. - -Lemma odd_nth_roots_quot n (p : {poly R}) : - mx_eval_ker_pred (odd_nth_roots' n) p <-> - princ_ideal_pred ('X^(2^n) + 1%:P) p. -Proof. - unfold mx_eval_ker_pred, princ_ideal_pred. - split; intros. - - apply /eqP. - apply odd_nth_roots_minpoly_mult_R. - intros. - move=> /eqP in H. - apply matrixP in H. - unfold odd_nth_roots' in H. - unfold odd_nth_roots. - unfold root. - destruct i. - assert (m < (sval (pow2_S n)).+1) by (simpl; lia). - specialize (H 0 (Ordinal H0)). - simpl in H. - rewrite !mxE in H. - rewrite !mxE. - replace (2 * Ordinal (n:=2 ^ n) (m:=m) i + 1)%N with - (2 * Ordinal (n:=(2 ^ n - 1).+1) (m:=m) H0 + 1)%N. - + by rewrite H. - + f_equal. - - unfold zero; simpl. - unfold const_mx. - apply /eqP. - apply matrixP. - intros ??. - rewrite !mxE. - move=> /eqP in H. - generalize (minpoly_mult_odd_nth_roots_R n p H); intros. - unfold pow2_S in y. - destruct y as [y' ?]. - assert (y' < 2^n) by lia. - specialize (H0 (Ordinal H1)). - unfold root in H0. - move=> /eqP in H0. - simpl in H0. - unfold odd_nth_roots in H0. - rewrite mxE in H0. - apply H0. - Qed. - - -Section norms. - - Import matrix_ring. - - Implicit Types x y : R[i]. - - Notation normc := ComplexField.Normc.normc. - - Definition norm1 {n} (v : 'rV[R[i]]_n):R := \sum_(j < n) normc (v 0 j). - - Definition norm_inf {n} (v : 'rV[R[i]]_n):R := \big[Order.max/0]_(j < n) normc (v 0 j). - - Definition cvec_norm_inf {n} (v : 'cV[R[i]]_n):R := norm_inf(v^T). - Definition cvec_norm1 {n} (v : 'cV[R[i]]_n):R := norm1(v^T). - - Definition matrix_norm_inf {n m} (mat : 'M[R[i]]_(n,m)) := - \big[Order.max/0]_(j i _. - rewrite (bigD1 i) // /= big1_idem ?addr0 //. - - by rewrite /diag_mx mxE eq_refl /= mulr1n. - - intros. - rewrite mxE eq_sym. - by rewrite (negbTE H) ComplexField.Normc.normc0. - Qed. - - Lemma normc_nneg (x : R[i]) : - (R0 <= normc x)%O. - Proof. - rewrite /normc. - case: x => r i. - apply ssrnum.Num.Theory.sqrtr_ge0. - Qed. - - Lemma sum_normc_nneg {n} (v : 'rV[R[i]]_n) : - ((@zero R_zmodType) <= \sum_(k < n) normc (R:=R_rcfType) (v ord0 k))%O. - Proof. - apply big_rec => // i x _ xnneg. - by rewrite ssrnum.Num.Theory.addr_ge0 // normc_nneg. - Qed. - - Lemma mat_vec_norm_inf {n} (v : 'rV[R[i]]_n) : - norm_inf v = matrix_norm_inf (v^T). - Proof. - rewrite /norm_inf /matrix_norm_inf /=. - apply eq_bigr => j _. - by rewrite big_ord_recl big_ord0 addr0 mxE. - Qed. - - Lemma mat_vec_norm1 {n} (v : 'rV[R[i]]_n) : - norm1 v = matrix_norm1 (v^T). - Proof. - rewrite /norm1 /matrix_norm1 /matrix_norm_inf /=. - rewrite big_ord_recl big_ord0 Order.POrderTheory.max_l ?sum_normc_nneg //. - by apply eq_bigr => j _; rewrite !mxE. - Qed. - - - Lemma R00 : R0 = 0. - Proof. - by []. - Qed. - - - Lemma normc_nnegR (x : R[i]) : - Rle R0 (normc x). - Proof. - move: (normc_nneg x). - rewrite /Order.le /=. - by move/RlebP. - Qed. - - Lemma maxrM (c a b : R) : Rle 0 c -> Order.max (c * a) (c * b) = c * Order.max a b. - Proof. - rewrite /Order.max /Order.lt /=. - (repeat case: RltbP); [lra | | | lra]; intros. - - destruct H. - + apply Rmult_lt_reg_l in p => //. - + subst. - by rewrite !mul0r. - - destruct H. - + elim n. - by apply Rmult_lt_compat_l. - + subst. - by rewrite !mul0r. - Qed. - - Lemma maxrM_l (c a b : R) : Rle 0 c -> Order.max (a * c) (b * c) = (Order.max a b)*c. - Proof. - rewrite /Order.max /Order.lt /=. - (repeat case: RltbP); [lra | | | lra]; intros. - - destruct H. - + apply Rmult_lt_reg_r in p => //. - + subst. - by rewrite !mulr0. - - destruct H. - + elim n. - by apply Rmult_lt_compat_r. - + subst. - by rewrite !mulr0. - Qed. - - Lemma norm_infZ {n} (c : R[i]) (v : 'rV[R[i]]_n) : - norm_inf (c *: v) = (normc c) * norm_inf v. - Proof. - rewrite /norm_inf. - apply (big_rec2 (fun a b => a = normc c * b)). - - by rewrite mulr0. - - move=> i a b _ ->. - rewrite mxE ComplexField.Normc.normcM maxrM //. - apply normc_nnegR. - Qed. - - Lemma norm_inf_nneg {n} (v : 'rV[R[i]]_n) : - (@zero R_ringType <= norm_inf v)%O. - Proof. - rewrite /norm_inf. - apply big_rec => //= i x _ xn. - by rewrite Order.TotalTheory.le_maxr xn orbT. - Qed. - - Lemma normc_triang (x y : R[i]) : - (normc (x + y) <= normc x + normc y)%O. - Proof. - generalize (ComplexField.lec_normD x y); intros. - rewrite /ComplexField.lec /= addr0 in H. - move => /andP in H. - by destruct H. - Qed. - - Lemma normc_triangR (x y : R[i]) : - Rle (normc (x + y)) (normc x + normc y). - Proof. - move: (normc_triang x y) => /=. - rewrite /Order.le /Order.POrder.le /=. - by move/RleP. - Qed. - - Lemma normc_triang_sum {n} (a : 'I_n -> R[i]) : - Rleb (normc (\sum_(j R) (c : R) : - (\sum_(j R) (c : R) : - c * (\sum_(j R) (c : R) : - Rle 0 c -> - (\big[Order.max/0]_(j R) (c : R) : - Rle 0 c -> - c * (\big[Order.max/0]_(j R) : - (forall j, Rleb (a j) (b j)) -> - Rleb (\sum_(j - Rleb c d -> - Rleb (Order.max a c) (Order.max b d). - Proof. - rewrite -!RmaxE. - intros. - apply /RlebP. - move /RlebP in H. - move /RlebP in H0. - apply Rmax_case; apply Rmax_Rle. - - by left. - - by right. - Qed. - - Lemma bigmax_le2 {n} (a b : 'I_n -> R) : - (forall j, Rleb (a j) (b j)) -> - Rleb (\big[Order.max/0]_(jR): - Order.le init (\big[Order.max/init]_(j < n) f (v 0 j)). - Proof. - rewrite BigOp.bigopE. - unlock reducebig. - elim: (index_enum _) => /=. - - by exact: Order.POrderTheory.lexx. - - move=> a l. - rewrite Order.TotalTheory.le_maxr => ->. - by rewrite orbT. - Qed. - - Lemma omax_l {disp : Datatypes.unit} {T : porderType disp} (x y:T) : Order.le x (Order.max x y). - Proof. - rewrite Order.POrderTheory.maxEle. - by case_eq (Order.le x y). - Qed. - - Lemma omax_r {disp : Datatypes.unit} {T : orderType disp} (x y:T) : Order.le y (Order.max x y). - Proof. - rewrite Order.POrderTheory.maxElt. - case_eq (Order.lt x y) => //. - rewrite (Order.TotalTheory.leNgt y x). - by move/negbT. - Qed. - - Lemma omax_l_real (x y:R) : Rleb x (Order.max x y). - Proof. - by exact: (omax_l x y). - Qed. - - Lemma omax_r_real (x y:R) : Rleb y (Order.max x y). - Proof. - by exact: (omax_r x y). - Qed. - - Lemma bigmaxr_le {n} (v : 'rV[R[i]]_n) init f i: - Rleb (f (v 0 i)) (\big[Order.max/init]_(j < n) f (v 0 j)). - Proof. - rewrite BigOp.bigopE. - unlock reducebig. - move: (mem_index_enum i). - elim: (index_enum _) => /= [| a l IHl]. - - by rewrite seq.in_nil. - - rewrite seq.in_cons => /orP [/eqP->| ]. - + by exact: omax_l_real. - + move=> inn. - move: (IHl inn) => /RlebP=>le1. - apply/RlebP. - eapply Rle_trans; try apply le1. - apply/RlebP. - apply omax_r_real. - Qed. - - Lemma bigmaxr_le_alt {n} (v : 'I_n -> R) init i: - Rleb (v i) (\big[Order.max/init]_(j < n) (v j)). - Proof. - rewrite BigOp.bigopE. - unlock reducebig. - move: (mem_index_enum i). - elim: (index_enum _) => /= [| a l IHl]. - - by rewrite seq.in_nil. - - rewrite seq.in_cons => /orP [/eqP->| ]. - + by exact: omax_l_real. - + move=> inn. - move: (IHl inn) => /RlebP=>le1. - apply/RlebP. - eapply Rle_trans; try apply le1. - apply/RlebP. - apply omax_r_real. - Qed. - - Lemma mat_vec_norm_bound1 {n m} - (mat : 'M[R[i]]_(n, m)) - (vec : 'rV[R[i]]_m) k: - Rleb (normc (\sum_(j j. - rewrite ComplexField.Normc.normcM. - apply /RlebP. - apply Rmult_le_compat_l. - + apply normc_nnegR. - + apply /RlebP. - apply bigmaxr_le. - Qed. - - - Lemma bigmax_nneg {n} (v : 'I_n -> R) : - (forall i, Rleb 0 (v i)) -> - Rleb 0 (\big[Order.max/0]_(j < n) (v j)). - Proof. - intros. - apply big_rec. - - apply /RlebP. - unfold zero; simpl. - coq_lra. - - intros. - assert ((zero R_ringType) <= Order.max (v i) x)%O. - { - rewrite Order.TotalTheory.le_maxr. - apply /orP. - left. - apply H. - } - by rewrite /Order.le /= in H1. - Qed. - - Lemma bigmax_normc_nneg {n} (v : 'rV[R[i]]_n): - Rleb 0 (\big[Order.max/0]_(j < n) normc (v 0 j)). - Proof. - apply bigmax_nneg => i. - apply normc_nneg. - Qed. - - Lemma sum_nneg {n} (v : 'I_n -> R) : - (forall i, Rleb 0 (v i)) -> - Rleb 0 (\sum_(j < n) (v j)). - Proof. - intros. - apply big_rec. - - apply /RlebP. - unfold zero; simpl. - coq_lra. - - intros. - apply /RlebP. - rewrite /add /=. - apply Rplus_le_le_0_compat; by apply /RlebP. - Qed. - - Lemma matrix_vec_norm_inf_sub_mult {n m} - (mat : 'M[R[i]]_(n, m)) - (vec : 'rV[R[i]]_m) : - Rleb (matrix_norm_inf (mat *m (vec^T))) - ((matrix_norm_inf mat) * (norm_inf vec)). - Proof. - rewrite /matrix_norm_inf /norm_inf /mulmx /=. - generalize (@max_mult_distr n); intros. - rewrite /mul /= in H. - rewrite H. - - apply bigmax_le2. - intros. - rewrite big_ord_recl big_ord0 addr0 mxE /trmx /=. - under eq_bigr do rewrite mxE. - by rewrite mat_vec_norm_bound1. - - apply /RlebP. - apply bigmax_normc_nneg. - Qed. - - Lemma sum_plus {n} (a b : 'I_n -> R) : - \sum_(i 'I_m -> R) : - \sum_(i j. - under eq_bigr do rewrite mxE. - apply /RlebP. - apply Rle_trans with - (r2 := \sum_(i < p) (\sum_(j0 < m) normc (R:=R_rcfType) (mat1 j j0 * mat2 j0 i))). - + apply /RlebP. - apply sum_le => j0. - apply normc_triang_sum. - + rewrite exchange_sums. - generalize (@sum_mult_distr); intros. - rewrite /mul /= in H0. - rewrite H0. - apply /RlebP. - apply sum_le => j0. - replace (\sum_(i < p) normc (R:=R_rcfType) (mat1 j j0 * mat2 j0 i)) with - ((normc (R:=R_rcfType) (mat1 j j0)) * - (\sum_(i < p) normc (R:=R_rcfType) (mat2 j0 i))). - * apply /RlebP. - apply Rmult_le_compat_l. - -- apply normc_nnegR. - -- apply /RlebP. - generalize (@bigmaxr_le_alt m - (fun j0 => (\sum_(i < p) normc (R:=R_rcfType) (mat2 j0 i))) 0 j0); intros. - apply H1. - * rewrite mulrC sum_mult_distr. - apply eq_bigr => i _. - by rewrite mulrC ComplexField.Normc.normcM. - - apply /RlebP. - apply bigmax_nneg => i. - apply sum_nneg => k. - apply normc_nneg. - Qed. - - Lemma matrix_norm_inf_scale {n m} (mat : 'M[R[i]]_(n,m)) (c : R[i]) : - matrix_norm_inf (scalemx c mat) = (normc c)*(matrix_norm_inf mat). - Proof. - rewrite /matrix_norm_inf /scalemx max_mult_distl. - - apply eq_bigr => j _. - rewrite sum_mult_distl. - apply eq_bigr => k _. - by rewrite mxE ComplexField.Normc.normcM. - - apply normc_nnegR. - Qed. - - Lemma big_max_const (n:nat) (c : R) : n != 0%nat -> - Rle 0 c -> - \big[Order.max/0]_(j < n) c = c. - Proof. - case: n; [lia |intros n _]. - induction n. - - rewrite big_ord_recl big_ord0 => H. - rewrite Order.POrderTheory.max_l //. - apply /RlebP. - apply H. - - rewrite big_ord_recl => H. - rewrite IHn. - rewrite Order.POrderTheory.max_l //. - apply H. - Qed. - - Lemma normc_conj (x : R[i]) : - ComplexField.Normc.normc x = ComplexField.Normc.normc (conjc x). - Proof. - case: x => rx ix /=. - by rewrite sqrrN. - Qed. - - Lemma normc_nth_root j (n:nat) : - n != 0%nat -> - normc (nth_root j n) = 1. - Proof. - rewrite /normc /nth_root => nN0. - rewrite -!RpowE -!Rsqr_pow2 addrC /add /=. - by rewrite sin2_cos2 ssrnum.Num.Theory.sqrtr1. - Qed. - - Lemma big_max_const_fun (n : nat) (a : 'I_n -> R) (c : R) : n != 0%nat -> - Rle 0 c -> - (forall i, a i = c) -> - \big[Order.max/0]_(i < n) (a i) = c. - Proof. - intros. - under eq_bigr do rewrite H1. - by apply big_max_const. - Qed. - - Lemma norm_inf_const_norm (n : nat) (vec : 'rV[R[i]]_n.+1) : - (forall i , normc (R:=R_rcfType) (vec 0 i) = 1) -> - norm_inf vec = 1. - Proof. - intros. - rewrite /norm_inf. - apply big_max_const_fun; trivial. - rewrite /one /=; coq_lra. - Qed. - - Lemma pow2n0 n : (2 ^ n)%N != 0%N. - Proof. - by rewrite expn_eq0. - Qed. - - Hint Immediate pow2n0. - - Lemma norm_inf_conj_half_roots (n : nat) : - norm_inf (map_mx conjc (nth_roots_half n.+1)) = 1. - Proof. - rewrite /norm_inf /nth_roots_half. - apply big_max_const_fun. - - apply pow2n0. - - rewrite /one/=; coq_lra. - - intros. - by rewrite !mxE -normc_conj normc_nth_root // pow2n0. - Qed. - - Lemma big_sum_const (n : nat) (c : R) : - \sum_(j i0 _. - rewrite /even_nth_roots !mxE -exp_conj -normc_conj pow_nth_root' ?normc_nth_root. - over. - apply pow2n0. - apply pow2n0. - by rewrite big_sum_const. - Qed. - - Lemma encode_mat_norm_inf (n : nat) : - let pmat := peval_mat (odd_nth_roots (S n)) in - let encmat := (conj_mat (pmat^T)) in - Rleb (matrix_norm_inf encmat) (2^S n)%:R. - Proof. - rewrite /= encode_mat_prod. - apply /RlebP. - eapply Rle_trans. - - apply /RlebP. - apply matrix_norm_inf_sub_mult. - - rewrite -norm_inf_diag norm_inf_conj_half_roots Rmult_1_l. - rewrite norm_inf_peval_mat_conj_even_roots. - coq_lra. - Qed. - - - Lemma norm_inf_triang {n} (v1 v2 : 'rV[R[i]]_n) : - (norm_inf (v1 + v2) <= norm_inf v1 + norm_inf v2)%O. - Proof. - rewrite /norm_inf. - apply big_rec. - - unfold Order.le; simpl. - apply /RlebP. - erewrite <- addr0 at 1. - apply Rplus_le_compat; apply /RlebP; apply bigmax_normc_nneg. - - intros i x _ xn. - rewrite Order.TotalTheory.le_maxl. - generalize (normc_triang (v1 0 i) (v2 0 i)); intros. - rewrite mxE. - apply /andP; split; trivial. - unfold Order.le in *; simpl in *. - apply /RlebP. - move => /RlebP in H. - eapply Rle_trans. - apply H. - apply Rplus_le_compat; apply /RlebP; apply bigmaxr_le. - Qed. - - Lemma norm_inf_semi_multiplicative {n} (v1 v2 : 'rV[R[i]]_n) : - (norm_inf (map2_mx (fun (a b : R[i]) => a * b) v1 v2) <= norm_inf v1 * norm_inf v2)%O. - Proof. - rewrite /norm_inf. - apply big_rec. - - unfold Order.le, zero, mul; simpl. - apply /RlebP; apply Rmult_le_pos; apply /RlebP; apply bigmax_normc_nneg. - - intros i x _ xn. - rewrite Order.TotalTheory.le_maxl mxE ComplexField.Normc.normcM. - apply /andP; split; trivial. - clear x xn. - unfold Order.le; simpl. - apply /RlebP. - apply Rmult_le_compat; try apply normc_nnegR; apply /RlebP; apply bigmaxr_le. - Qed. - - Lemma norm_inf_pos_def {n} (v : 'rV[R[i]]_n) : - norm_inf v = 0 -> v = 0. - Proof. - rewrite /norm_inf => HH. - apply /matrixP => a b. - move: (ord1 a)->. - move: (bigmaxr_le v 0 (@normc _) b). - rewrite {}HH. - move/RlebP => HH. - rewrite [v 0 b]ComplexField.eq0_normC ?mxE//. - move: (normc_nnegR (v 0 b)) => HH2. - rewrite /zero /=. - f_equal. - now apply Rle_antisym. - Qed. - - Lemma normc_Rabs (r : R) : - normc (RtoC r) = Rabs r. - Proof. - rewrite /normc /RtoC (expr2 0) mulr0 addr0. - by rewrite ssrnum.Num.Theory.sqrtr_sqr. - Qed. - - Lemma mx_evalZ {n} (v : 'rV[R[i]]_n.+1) (r:R) p : - mx_eval v (r *: p) = (RtoC r) *: (mx_eval v p). - Proof. - apply matrixP => a b. - rewrite !mxE /scale /= scale_polyE. - rewrite rmorphismMP /= (map_polyC RtoC_rmorphism) /=. - by rewrite -hornerZ /= /scale /= scale_polyE. - Qed. - - (* following 4 lemmas show canon_norm_inf is a norm on quotient by x^+(2^n) + 1 *) - Lemma canon_norm_infZ n (r : R) (p : {poly R}) : - canon_norm_inf n (r *: p) = Rabs r * canon_norm_inf n p. - Proof. - by rewrite /canon_norm_inf !mx_evalZ norm_infZ normc_Rabs. - Qed. - - Lemma canon_norm_inf_nneg n (p : {poly R}) : - (zero R_ringType <= canon_norm_inf n p)%O. - Proof. - apply norm_inf_nneg. - Qed. - - Lemma canon_norm_inf_triang n (p q : {poly R}) : - (canon_norm_inf n (p + q) <= canon_norm_inf n p + canon_norm_inf n q)%O. - Proof. - rewrite /canon_norm_inf. - move: (raddfD (mx_eval_rmorphism (odd_nth_roots' n)) p q) => /= ->. - apply norm_inf_triang. - Qed. - - Lemma canon_norm_inf_semi_multiplicative n (p q : {poly R}) : - (canon_norm_inf n (p * q) <= canon_norm_inf n p * canon_norm_inf n q)%O. - Proof. - rewrite /canon_norm_inf rmorphM. - apply norm_inf_semi_multiplicative. - Qed. - - Lemma mx_eval_pow {n} (v : 'rV_n.+1) e : - mx_eval v 'X^e = map_mx (fun c => c ^ e) v. - Proof. - apply eq_map_mx. - intros ?. - by rewrite map_polyXn hornerXn /exprz. - Qed. - - - Lemma iter_max {disp : Datatypes.unit} {T : porderType disp} i (a b: T) : - i != 0%nat -> - ssrnat.iter i (Order.max a) b = Order.max a b. - Proof. - induction i => //=. - destruct i => // _. - rewrite IHi //=. - by rewrite Order.POrderTheory.max_maxxK. - Qed. - - Lemma canon_norm_inf_pow n e : - canon_norm_inf n 'X^e = 1. - Proof. - rewrite /canon_norm_inf mx_eval_pow /odd_nth_roots'. - unfold norm_inf. - under eq_big_seq. - { - intros. - rewrite !mxE. - rewrite /exprz. - rewrite pow_nth_root'; try lia. - rewrite normc_nth_root; try lia. - over. - } - simpl. - rewrite big_const_ord iter_max //. - rewrite /Order.max. - case: RltP => //. - rewrite /one/zero/=. - move/lt_IZR. - lia. - Qed. - - Lemma canon_norm_inf_C_pow n e c : - canon_norm_inf n (c *: 'X^e) = Rabs c. - Proof. - by rewrite canon_norm_infZ canon_norm_inf_pow mulr1. - Qed. - - Lemma canon_norm_inf_C n (c : R) : - canon_norm_inf n (polyC c) = Rabs c. - Proof. - rewrite -(canon_norm_inf_C_pow n 0 c). - f_equal. - by rewrite expr0 /scale /Lmodule.scale /= scale_polyE mulr1. - Qed. - - Lemma coef_norm1_C (c : R) : - coef_norm1 (polyC c) = Rabs c. - Proof. - rewrite /coef_norm1 size_polyC. - case: eqVneq; simpl. - - rewrite big_ord0 => ->. - by rewrite Rabs_R0. - - by rewrite big_ord1 coefC. - Qed. - - Lemma size_poly_def (p : {poly R}) : - size (\sum_(i < size p) p`_i *: 'X^i) = size p. - Proof. - case: (eqVneq p 0); intros. - - by rewrite e poly_size_0 big_ord0 poly_size_0. - - rewrite -poly_def size_poly_eq //. - by rewrite -lead_coefE lead_coef_eq0. - Qed. - - Lemma coef_norm1_poly_def (p : {poly R}) : - coef_norm1 (\sum_(i < (size p)) p`_i *: 'X^i) = - \sum_(i < (size p)) Rabs p`_i. - Proof. - rewrite /coef_norm1 size_poly_def. - apply eq_big => //= i _. - f_equal. - by rewrite -poly_def coef_poly ltn_ord. - Qed. - - Lemma canon_norm_inf_poly_def n m (p : {poly R}) : - (canon_norm_inf n (\sum_(i < m) p`_i *: 'X^i) <= - \sum_(i < m) canon_norm_inf n (p`_i *: 'X^i))%O. - Proof. - induction m. - - rewrite !big_ord0 canon_norm_inf_C Rabs_R0. - apply Order.POrderTheory.le_refl. - - - move: (@big_nat_recr {poly R} 0 (add_monoid _) - m - 0 - (fun i => p`_i *: 'X^i) - (leq0n _) - ). - rewrite !big_mkord => ->. - move: (@big_nat_recr R 0 (@add_monoid _) - m - 0 - (fun i => canon_norm_inf n (p`_i *: 'X^i)) - (leq0n _) - ). - rewrite !big_mkord => -> /=. - eapply Order.POrderTheory.le_trans. - + apply canon_norm_inf_triang. - + by apply ssrnum.Num.Theory.lerD. - Qed. - - Lemma canon_norm_inf_le_norm1 n (p : {poly R}) : - (canon_norm_inf n p <= coef_norm1 p)%O. - Proof. - rewrite -(coefK p) poly_def coef_norm1_poly_def. - eapply Order.POrderTheory.le_trans. - apply canon_norm_inf_poly_def. - rewrite /Order.le /Order.POrder.le /=. - apply /RlebP. - right. - apply eq_big_seq => ??. - by rewrite canon_norm_inf_C_pow. - Qed. - - Lemma canon_norm_inf_val n (p : {poly R}) (i : 'I_(2^n-1).+1) : - (normc ((map_poly RtoC p).[odd_nth_roots' n 0 i]) <= canon_norm_inf n p)%O. - Proof. - rewrite /canon_norm_inf /norm_inf /=. - generalize (@bigmaxr_le (2^n-1).+1 (odd_nth_roots' n) 0 (fun c => normc (map_poly RtoC p).[c]) i) => HH. - eapply Order.POrderTheory.le_trans. - - rewrite /Order.le/=. - apply HH. - - rewrite /Order.le/=. - apply /RlebP. - right. - apply eq_big_seq => ??. - f_equal. - by rewrite /mx_eval !mxE. - Qed. - - Lemma peval_mx_eval {n} (v :'rV[R[i]]_n.+1) (p : {poly R}) : - size p <= n.+1 -> - let pmat := peval_mat v in - let pC := map_poly RtoC p in - mx_eval v p = (pmat *m (poly_rV pC)^T)^T. - Proof. - intros. - rewrite /mx_eval /peval_mat /map_mx trmx_mul trmxK /poly_rV /= /mulmx. - apply matrixP => j k. - rewrite !mxE /= horner_coef -map_RtoC_size. - unfold pmat, peval_mat, pC. - under [RHS]eq_bigr do rewrite !mxE. - case : (eqVneq (size p) n.+1); intros. - - rewrite e /=. - apply eq_big_seq => ??. - by rewrite fintype.ord1. - - transitivity (\sum_(i0 < (size p + (n.+1-size p))%nat) - (map_poly (aR:=R_ringType) (rR:=ComplexField.complex_ringType R_fieldType) RtoC p)`_i0 * v 0 k ^+ i0); - [| by have ->: (size p + (n.+1 - size p) = n.+1)%nat by lia]. - rewrite big_split_ord /=. - rewrite !fintype.ord1. - suff ->: ( \sum_(i0 < n.+1 - size p) - (map_poly (aR:=R_ringType) (rR:=ComplexField.complex_ringType R_fieldType) RtoC p)`_ - (size p + i0) * v 0 k ^+ (size p + i0) = 0). - { by rewrite addr0. } - - under eq_bigr => si _. - { rewrite nth_default ?mul0r. - - over. - - rewrite -map_RtoC_size. - lia. - } - by rewrite big_const_seq iter_addr_0 mul0rn. - Qed. - -Lemma decode_encode_scalar_mx' (n : nat): - let pmat := (peval_mat (odd_nth_roots' (S n))) in - pmat *m (conj_mat (pmat^T)) = scalar_mx (2^S n)%:R. -Proof. - Proof. - apply/matrixP. - move/matrixP: (decode_encode_scalar_mx n) => H i j. - have eqq: (sval (pow2_S n.+1)).+1 = (2^n.+1)%nat by (simpl; lia). - move: (H (cast_ord eqq i) (cast_ord eqq j)). - rewrite !mxE /= => <-. - rewrite (big_ord_widen_leq (2^n.+1)%N); try lia. - apply eq_big. - - move=> [??] /=. - by lia. - - intros. - by rewrite !mxE inordK. - Qed. - - Lemma encode_mat_norm_inf' (n : nat) : - let pmat' := peval_mat (odd_nth_roots' (S n)) in - let encmat' := (conj_mat (pmat'^T)) in - Rleb (matrix_norm_inf encmat') (2^S n)%:R. - Proof. - generalize (encode_mat_norm_inf n); intros. - unfold matrix_norm_inf in *; simpl in *. - rewrite (big_ord_widen_leq (2^n.+1)%N); try lia. - apply /RlebP. - move /RlebP in H. - eapply Rle_trans; cycle 1. - apply H. - right. - apply eq_big. - - intros ?. - destruct x. - simpl. - lia. - - intros. - rewrite (big_ord_widen_leq (2^n.+1)%N); try lia. - apply eq_big. - + intros ?. - destruct x. - simpl. - lia. - + intros. - rewrite /odd_nth_roots /odd_nth_roots' !mxE !inordK //. - Qed. - - - Lemma encmat_pmat (n : nat) : - let pmat' := peval_mat (odd_nth_roots' (S n)) in - let encmat := (conj_mat (pmat'^T)) in - pmat' *m ((RtoC (inv (2^S n)%:R)) *: encmat) = scalar_mx 1. - Proof. - intros. - rewrite -scalemxAr /encmat /pmat' decode_encode_scalar_mx'. - apply /matrixP => i j. - rewrite !mxE mulrnAr -RtoCR -rmorphM /= mulrC divrr // unitfE. - by rewrite natmul0eq pow2n0. - Qed. - - Lemma invmx_comm (n : nat) (A B : 'M[R[i]]_n) : - A *m B = scalar_mx 1 -> - B *m A = scalar_mx 1. - Proof. - intros. - destruct (mulmx1_unit H). - generalize (mulmxV H1); intros. - generalize H2; intros. - apply (f_equal (fun z => A *m z)) in H2. - rewrite mulmxA H mul1mx mulmx1 in H2. - by rewrite H2 in H3. - Qed. - - Lemma scalar_prod_comm (n : nat) (A B : 'M[R[i]]_n) (c : R[i]) : - c != 0 -> - A *m B = scalar_mx c -> - B *m A = scalar_mx c. - Proof. - intros. - apply (f_equal (fun z => (inv c) *: z)) in H0. - rewrite scalemxAl scale_scalar_mx mulrC divff // in H0. - apply invmx_comm in H0. - rewrite -scalemxAr in H0. - apply (f_equal (fun z => c *: z)) in H0. - rewrite scalerA divff // scalemx1 in H0. - by rewrite -(scale1mx (B *m A)) -H0. - Qed. - - Lemma encmat_pmat_alt (n : nat) : - let pmat' := peval_mat (odd_nth_roots' (S n)) in - let encmat := (conj_mat (pmat'^T)) in - ((RtoC (inv (2^S n)%:R)) *: encmat) *m pmat' = scalar_mx 1. - Proof. - intros. - apply invmx_comm. - apply encmat_pmat. - Qed. - -Lemma decode_encode_off_diag_T (n : nat): - let pmat := (peval_mat (odd_nth_roots' (S n))) in - forall n1 n2, - n1 <> n2 -> - H_inner_prod (col n1 pmat)^T (col n2 pmat)^T = C0. -Proof. - intros. - rewrite !tr_col -H_inner_prod_mat trmxK. - generalize (encmat_pmat_alt n); intros. - simpl in H0. - apply (f_equal (fun m => trmx ((RtoC (2 ^ n.+1)%:R) *: m))) in H0. - rewrite scalemxAl scalerA in H0. - replace ((RtoC (2 ^ n.+1)%:R * RtoC (2 ^ n.+1)%:R^-1)) with (RtoC 1) in H0. - - rewrite trmx_mul scale1r conj_transpose trmxK scalemx1 in H0. - rewrite /pmat H0 tr_scalar_mx mxE. - case: (eqVneq n1 n2); intros. - + by rewrite e in H. - + by rewrite /= mulr0n. - - rewrite -rmorphM /= divff // natmul0eq. - lia. -Qed. - - Lemma encmat_pmat_pvec (n : nat) (p : {poly R}) : - let pmat' := peval_mat (odd_nth_roots' (S n)) in - let encmat := (conj_mat (pmat'^T)) in - let pvec := (poly_rV (d := (sval (pow2_S n.+1)).+1) - (map_poly (aR:=R_ringType) - (rR:=ComplexField.complex_ringType - R_fieldType) RtoC p))^T in - (((RtoC (inv (2^S n)%:R)) *: encmat) *m pmat') *m pvec = pvec. - Proof. - by rewrite /= encmat_pmat_alt mul1mx. - Qed. - - Lemma big_max_split {k2 : nat} (k1 : nat) (F : 'I_k2 -> R) : - \big[Order.max/0]_(j < k2) F j = - Order.max - (\big[Order.max/0]_(j < k2 | true && (j < k1)%N) F j) - (\big[Order.max/0]_(j < k2 | true && ~~(j < k1)) F j). - Proof. - rewrite -Order.TotalTheory.bigmaxID. - by apply eq_bigr. - Qed. - -Lemma big_max_nneg_with_trailing_zeros {k1 k2} (le12: k1 <= k2) (F: 'I_k2 -> R) : - (forall i, Rle 0 (F i)) -> - (forall i: 'I_k2 , k1 <= i -> F i = 0%R) -> - \big[Order.max/0]_(j < k2) F j = \big[Order.max/0]_(j < k1) F (widen_ord le12 j). - Proof. - intros Fnneg Ftrail0. - rewrite (big_max_split k1). - assert (\big[Order.max/0]_(j < k2 | true && ~~ (j < k1)) F j = 0). - { - under eq_bigr. - intros. - rewrite Ftrail0; try lia. - over. - rewrite big_const_seq iter_fix // -RmaxE /zero/= Rmax_left //; coq_lra. - } - rewrite H -RmaxE Rmax_left. - destruct k1. - - rewrite big_ord0. - pose G : ('I_k2 -> R_orderType) := fun=> 0%R. - assert (\big[Order.max/0]_(j < k2 | true && (j < 0)) F j = - \big[Order.max/0]_(j < k2 | true && (j < 0)) G j). - { - apply congr_big; trivial. - intros ??. - lia. - } - rewrite H0 /G. - rewrite big_const_seq iter_fix // -RmaxE /zero/= Rmax_left //; coq_lra. - - rewrite [RHS](big_ord_widen_leq k2); trivial. - apply eq_bigr. - intros. - f_equal. - apply ord_inj => /=. - by rewrite inordK. - - pose G : ('I_k2 -> R_orderType) := fun=> 0%R. - assert ((\big[Order.max/0]_(j < k2 | true && (j < k1)%N) (G j)) = 0)%R. - { - rewrite /G big_const_seq iter_fix // -RmaxE /zero/= Rmax_left //; coq_lra. - } - rewrite -{1}H0. - apply /RleP. - apply (@Order.TotalTheory.le_bigmax2 _ R_orderType). - intros. - rewrite /G. - apply /RleP. - apply Fnneg. - Qed. - - - - - - Lemma coef_maxnorm_pvec n (p : {poly R}) : - size p <= 2^n.+1 -> - let pvec := (poly_rV (d := (sval (pow2_S n.+1)).+1) - (map_poly (aR:=R_ringType) - (rR:=ComplexField.complex_ringType - R_fieldType) RtoC p))^T in - coef_maxnorm p = cvec_norm_inf pvec. - Proof. - intros. - rewrite /coef_maxnorm /cvec_norm_inf /norm_inf. - assert ((sval (pow2_S n.+1)).+1 = (2^n.+1)%N). - { - simpl. - lia. - } - case : (eqVneq (size p) (sval (pow2_S n.+1)).+1); intros. - - rewrite e /=. - apply eq_big_seq => k HH. - rewrite /pvec !mxE. - rewrite /map_poly coef_poly /=. - case: ltP; rewrite normc_Rabs // => kbig. - rewrite nth_default //. - lia. - - have le1: (size p <= (sval (pow2_S n.+1)).+1) by lia. - rewrite (big_max_nneg_with_trailing_zeros le1). - + apply eq_bigr => j _. - rewrite /pvec /poly_rV !mxE. - rewrite /= /map_poly coef_poly. - destruct j. - simpl in i0. - by rewrite /= i0 normc_Rabs. - + intros. - apply/RleP. - apply normc_nneg. - + intros. - rewrite /pvec /= /poly_rV. - rewrite !mxE /= map_polyE. - rewrite nth_default. - * by rewrite ComplexField.Normc.normc0. - * eapply leq_trans. - -- apply size_Poly. - -- by rewrite size_map. - Qed. - - Lemma canon_norm_inf_pvec n (p : {poly R}) : - size p <= 2^n -> - let pmat' := peval_mat (odd_nth_roots' n) in - let pvec := (poly_rV (d := (sval (pow2_S n)).+1) - (map_poly (aR:=R_ringType) - (rR:=ComplexField.complex_ringType - R_fieldType) RtoC p))^T in - canon_norm_inf n p = cvec_norm_inf (pmat' *m pvec). - Proof. - intros. - rewrite /canon_norm_inf /cvec_norm_inf /norm_inf. - apply eq_big; trivial; intros. - f_equal. - rewrite /odd_nth_roots' /pmat' /pvec /peval_mat !mxE /map_poly horner_poly. - rewrite /odd_nth_roots'. - under [RHS]eq_bigr do rewrite !mxE mulrC coef_poly. - simpl. - case : (eqVneq (size p) (2^n-1).+1); intros. - - rewrite e. - apply eq_big_seq => k HH. - f_equal. - assert (k < (2 ^ n - 1).+1). - { - by destruct k; simpl. - } - by rewrite H1. - - transitivity (\sum_(i1 < size p + ((2^n-1).+1-size p)%nat) - (if i1 < size p then RtoC p`_i1 else 0) * - nth_root (2 * i + 1) (2 ^ n.+1) ^+ i1); - [| by have ->: (size p + ((2^n-1).+1 - size p) = (2^n-1).+1)%nat by lia]. - rewrite big_split_ord /=. - assert (\sum_(i1 < (2 ^ n - 1).+1 - size p) - (if size p + i1 < size p then RtoC p`_(size p + i1) else 0) * - nth_root (2 * i + 1) (2 ^ n.+1) ^+ (size p + i1) = 0). - { - under eq_bigr => si _. - { - have ->: ((size p + si < size p) = false) by lia. - rewrite mul0r. - over. - } - by rewrite big_const_seq iter_addr_0 mul0rn. - } - rewrite H1 addr0. - by apply eq_big => // [[? /= ->]]. - Qed. - - Lemma matrix_norm_inf_pmat_inv n : - let pmat' := peval_mat (odd_nth_roots' (S n)) in - let encmat := (conj_mat (pmat'^T)) in - Rle (matrix_norm_inf ((RtoC (inv (2^S n)%:R)) *: encmat)) 1. - Proof. - generalize (encode_mat_norm_inf' n); intros. - simpl in H. - move /RlebP in H. - rewrite matrix_norm_inf_scale normc_Rabs. - apply Rmult_le_compat_l with (r := Rabs (2 ^ n.+1)%:R^-1) in H. - - eapply Rle_trans. - apply H. - right. - assert (Rlt 0 (2 ^ n.+1)%:R). - { - rewrite -INRE. - apply lt_0_INR. - lia. - } - rewrite Rabs_right. - + rewrite /inv/= RmultRinvx // unitrE divff //. - rewrite /zero/=. - apply/eqP. - apply Rgt_not_eq. - apply H0. - + left. - rewrite -RinvE. - * apply Rinv_0_lt_compat; trivial. - * rewrite /zero/=. - apply/eqP. - apply Rgt_not_eq. - apply H0. - - apply Rabs_pos. - Qed. - - Lemma Rmult_le_1 (r1 r2 : R) : - Rle r1 1 -> - Rle 0 r2 -> - Rle (r1 * r2) r2. - Proof. - intros. - apply Rmult_le_compat_r with (r := r2) in H; trivial. - by rewrite Rmult_1_l in H. - Qed. - - Lemma matrix_norm_inf_pos {n m} (A : 'M[R[i]]_(n,m)) : - Rle 0 (matrix_norm_inf A). - Proof. - rewrite /matrix_norm_inf. - apply /RlebP. - apply bigmax_nneg. - intros. - apply sum_nneg. - intros. - apply /RlebP. - apply normc_nnegR. - Qed. - - Theorem coef_maxnorm_le_canon_norm_inf n (p : {poly R}) : - size p <= 2^n.+1 -> - Rle (coef_maxnorm p) (canon_norm_inf n.+1 p). - Proof. - intros. - rewrite (coef_maxnorm_pvec n) //. - rewrite canon_norm_inf_pvec //. - rewrite -{1}encmat_pmat_pvec /cvec_norm_inf. - rewrite !mat_vec_norm_inf !trmxK. - eapply Rle_trans. - - rewrite -mulmxA. - apply /RlebP. - apply matrix_norm_inf_sub_mult. - - generalize (matrix_norm_inf_pmat_inv n); intros. - simpl in H0. - apply Rmult_le_1; trivial. - apply matrix_norm_inf_pos. - Qed. - - Theorem canon_norm_inf_bounds n (p : {poly R}) : - size p <= 2^n.+1 -> - (coef_maxnorm p <= canon_norm_inf n.+1 p <= coef_norm1 p)%O. - Proof. - intros. - apply /andP. - split. - - apply /RleP. - by apply coef_maxnorm_le_canon_norm_inf. - - apply canon_norm_inf_le_norm1. - Qed. - - Lemma canon_norm_zero_mod_qpoly n (p : {poly R}) : - canon_norm_inf n p = 0 -> - Pdiv.Ring.rmodp (R:=R_ringType) p ('X^(2 ^ n) + 1%:P) = 0. - Proof. - intros. - apply odd_nth_roots_minpoly_mult_R. - intros. - unfold root. - apply /eqP. - generalize (canon_norm_inf_val n p); intros. - destruct i. - assert (m < (2^n-1).+1) by lia. - specialize (H0 (Ordinal H1)). - rewrite H in H0. - apply ComplexField.Normc.eq0_normc. - apply Order.POrderTheory.le_anti. - apply /andP; split. - - replace (odd_nth_roots n 0 (Ordinal (n:=2^n) i)) with - (odd_nth_roots' n 0 (Ordinal (n:=(2^n-1).+1) H1)). - + apply H0. - + by rewrite /odd_nth_roots' /odd_nth_roots !mxE. - - apply normc_nneg. - Qed. - -(* following only holds on quotient ring by x^+(2^n) + 1 - Lemma canon_norm_inf_pos_def n p : - canon_norm_inf n p = 0 -> p = 0. -*) - - - Lemma normc_conj_mul (x y : R[i]) : - normc (x * y) = normc (x * (conjc y)). - Proof. - by rewrite !ComplexField.Normc.normcM (normc_conj y). - Qed. - - Lemma normc_conj_add (r : R) (x y : R[i]) : - normc (x + y) = normc (conjc x + conjc y). - Proof. - by rewrite -rmorphD normc_conj. - Qed. - - Lemma normc_conj_exp (x : R[i]) n : - normc (x ^+ n) = normc ((conjc x) ^+ n). - Proof. - by rewrite -rmorphXn normc_conj. - Qed. - - Lemma RtoC1 : RtoC 1 = 1. - Proof. - by []. - Qed. - - Lemma RtoC0E (c:R) : (RtoC c == 0) = (c == 0). - Proof. - by rewrite /RtoC !eqE /= !eqE /= eqxx !andbT. - Qed. - - Lemma RtoC_real a : RtoC a \is ssrnum.Num.real. - Proof. - by rewrite complex_real. - Qed. - - Lemma conjc_id (a:R[i]) : (a^* = a)%C <-> (a \is ssrnum.Num.real). - Proof. - split. - - move/Cconj_im_0 => eqq. - apply/ssrnum.Num.Theory.Creal_ImP. - by rewrite -complexIm /= eqq. - - rewrite /in_mem/mem /= /Order.le /=. - case: a => ra ia /= /orP-[/andP-[/eqP-> _]|/andP-[/eqP<- _]] - ; by rewrite oppr0. - Qed. - - Lemma conjc_RtoC a : ((RtoC a)^* )%C = RtoC a. - Proof. - by rewrite /RtoC /= oppr0. - Qed. - - Lemma rpoly_eval_conj (p : {poly R}) (x : R[i]) : - let pc := map_poly RtoC p in - pc.[x]^*%C = pc.[x^*%C]. - Proof. - case: p => l llast. - rewrite /horner /= !map_polyE /=. - have/PolyK->: seq.last (RtoC 1) (seq.map RtoC l) != 0 - by rewrite seq.last_map RtoC0E. - move => {llast}. - elim: l => /=. - - by rewrite oppr0. - - move=> a l <-. - by rewrite rmorphD rmorphM /RtoC /= oppr0. - Qed. - - Lemma rpoly_eval_conj_R (p : {poly R}) (x : R[i]) : - let pc := map_poly RtoC p in - pc.[x] = pc.[x]^*%C -> - pc.[x] = pc.[x^* %C]. - Proof. - move=> /= -> . by rewrite rpoly_eval_conj. - Qed. - - Lemma normc_conj_poly (p : {poly R}) (x : R[i]) : - let pc := map_poly RtoC p in - normc (pc.[x]) = normc (pc.[x^*%C]). - Proof. - by rewrite /= -rpoly_eval_conj normc_conj. - Qed. - -End norms. - -Lemma pmat_normc_1 (n : nat) : - let pmat := peval_mat (odd_nth_roots (S n)) in - forall i j, - Cmod (pmat i j) = 1. -Proof. - simpl; intros. - unfold peval_mat, odd_nth_roots. - rewrite !mxE. - destruct (pow2_S (n.+2)). - move => /eqP in i0. - by rewrite i0 exp_nth_root nth_root_Cmod. -Qed. - -Definition coefE := - (coef0, coef1, coefC, coefX, coefXn, coef_sumMXn, - coefZ, coefMC, coefCM, coefXnM, coefMXn, coefXM, coefMX, coefMNn, coefMn, - coefN, coefB, coefD, coef_even_poly, coef_odd_poly, - coef_take_poly, coef_drop_poly, coef_cons, coef_Poly, coef_poly, - coef_deriv, coef_nderivn, coef_derivn, coef_map, coef_sum, - coef_comp_poly_Xn, coef_comp_poly). - -Lemma Cmod_normc (c : R[i]) : - Cmod c = ComplexField.Normc.normc c. -Proof. - unfold Cmod, ComplexField.Normc.normc. - destruct c. - rewrite RsqrtE//. - apply ssrnum.Num.Theory.addr_ge0; apply ssrnum.Num.Theory.sqr_ge0. -Qed. - -Lemma Cmod_triang (x y : R[i]) : - Rle (Cmod (x + y)) (Cmod x + Cmod y). -Proof. - rewrite !Cmod_normc. - by apply normc_triangR. -Qed. - -Lemma Cmod_mul (c1 c2 : R[i]): - Cmod (c1 * c2) = (Cmod c1) * (Cmod c2). -Proof. - by rewrite !Cmod_normc ComplexField.Normc.normcM. -Qed. - -Lemma cmod_1_delta (c1 c2 : R[i]) (delta : R) : - Cmod c1 = 1 -> - Rlt (Cmod c2) delta -> - Rlt (Cmod (c1 * c2)) delta. -Proof. - intros. - by rewrite Cmod_mul H mul1r. -Qed. - -Lemma root_eval_bound_cpoly (c : R[i]) (p : {poly R[i]}) (δ : R) : - Rle 0 δ -> - Cmod c = 1 -> - (forall i, Rle (Cmod (p`_ i)) δ) -> - Rle (Cmod (p.[c])) (δ *+ (seq.size p)). -Proof. - intros δnneg Cnorm1 coeffsmall. - rewrite -{1}[p]polyseqK horner_Poly. - move: (polyseq p) coeffsmall => {p}. - elim. - - move=> _/=. - rewrite mulr0n expr0n /= !mulr0n Rplus_0_l sqrt_0. - by apply Rle_refl. - - move=> a l IHl coeffsmall /=. - rewrite mulrS [δ + _]addrC. - rewrite !Cmod_normc. - eapply Rle_trans; [apply normc_triangR | ]. - rewrite ComplexField.Normc.normcM. - rewrite -!Cmod_normc Cnorm1 mulr1. - rewrite /add /=. - apply Rplus_le_compat. - + apply IHl. - move=>i. - by move: coeffsmall => /(_ (i.+1))/=. - + by move: coeffsmall => /(_ O) /= //. -Qed. - -Lemma RtoC_real_complex (r : R) : - RtoC r = real_complex _ r. -Proof. - reflexivity. -Qed. - -Lemma Cmod_Rabs (r : R) : - Cmod (RtoC r) = Rabs r. -Proof. - rewrite RtoC_real_complex /Cmod /=. - rewrite (expr2 0) mulr0 Rplus_0_r. - generalize (pow2_inv r (r^+ 2)); intros. - by rewrite H. -Qed. - -Lemma Cmod_0 : - Cmod 0 = 0. -Proof. - unfold Cmod. simpl. - by rewrite !expr2 !mulr0 Rplus_0_r sqrt_0. -Qed. - -Lemma root_eval_bound (c : R[i]) (p : {poly R}) (δ : R) : - Rle 0 δ -> - Cmod c = 1 -> - (forall i, Rle (Rabs (p`_ i)) δ) -> - Rle (Cmod (map_poly RtoC p).[c]) (δ *+ seq.size p). -Proof. - intros. - generalize (root_eval_bound_cpoly c (map_poly RtoC p) δ H H0); intros. - rewrite (size_map_poly RtoC_rmorphism p) in H2. - apply H2; intros. - rewrite coefE. - destruct (i < seq.size p)%N. - - by rewrite Cmod_Rabs. - - by rewrite Cmod_0. -Qed. - -Lemma Cmod_sum (n : nat) (cl : 'I_n -> R[i]) : - Rle (Cmod (\sum_i cl i)) (\sum_i (Cmod (cl i))). -Proof. - simpl. - elim: (index_enum (ordinal_finType n)) => /=. - - rewrite !big_nil /=. - rewrite !expr2 !mulr0 Rplus_0_r sqrt_0. - by apply Rle_refl. - - move=> a l IHl. - rewrite !big_cons /=. - eapply Rle_trans; [ apply Cmod_triang |]. - apply Rplus_le_compat; trivial. - by apply Rle_refl. -Qed. - -Lemma const_sum (n : nat) (δ : R) : - \sum_(i < n) δ = n%:R * δ. -Proof. - rewrite big_const_ord. - induction n. - + by rewrite /= mul0r. - + by rewrite /= IHn mulrS mulrDl mul1r. -Qed. - -Lemma leq_sum_R [I : Type] (r : list I) [P : pred I] [E1 E2 : I -> R] : - (forall i : I, P i -> Rle (E1 i) (E2 i)) -> - Rle (\sum_(i <- r | P i) E1 i) (\sum_(i <- r | P i) E2 i). -Proof. - apply big_ind2. - - by right. - - by apply Rplus_le_compat. -Qed. - -Lemma bounded_sum (n : nat) (cl : 'I_n -> R) (δ : R) : - Rle 0 δ -> - (forall i, Rle (cl i) δ) -> - Rle (\sum_i cl i) (n%:R * δ). -Proof. - intros. - apply Rle_trans with (r2 := \sum_(i < n) δ). - - by apply leq_sum_R. - - right. - by apply const_sum. - Qed. - -Lemma Cabs_sum_bound (n : nat) (cl : 'I_n -> R[i]) (δ : R) : - Rle 0 δ -> - (forall i, Rle (Cmod (cl i)) δ) -> - Rle (Cmod (\sum_i cl i)) (n%:R * δ). -Proof. - intros. - eapply Rle_trans. - apply Cmod_sum. - by apply bounded_sum. -Qed. - -Lemma decode_delta (n : nat) (cl : 'cV[R[i]]_(2^(S n))) (δ : R) : - Rle 0 δ -> - let pmat := peval_mat (odd_nth_roots (S n)) in - (forall i, Rle (Cmod (cl i 0)) δ) -> - forall i, Rle (Cmod ((pmat *m cl) i 0)) ((2^S n)%:R * δ). -Proof. - simpl; intros. - rewrite !mxE. - apply Cabs_sum_bound; trivial. - intros. - by rewrite Cmod_mul pmat_normc_1 mul1r. - Qed. - - Lemma mul_dvdn_l (x d1 d2:nat) : - (d1 * d2 %| x)%N -> (d1 %| x)%N. - Proof. - eapply dvdn_trans. - apply dvdn_mulr. - by apply dvdnn. - Qed. - - Lemma mul_dvdn_r (x d1 d2:nat) : - (d1 * d2 %| x)%N -> (d2 %| x)%N. - Proof. - rewrite mulnC. - by apply mul_dvdn_l. - Qed. - - Lemma modn_muln (x y b1 b2:nat) : - x == y %[mod b1 * b2] -> x == y %[mod b1]. - Proof. - wlog le_yx : x y / y <= x; last by (rewrite !eqn_mod_dvd //; apply mul_dvdn_l). - by have [?|/ltnW ?] := leqP y x; last rewrite !(eq_sym (x %% _)%N); apply. - Qed. - - Section unity. - Context {T : comRingType} - (z : T). - - Lemma two_pow_prim_root_alt (n:nat) : - z ^+ (2^n) <> 1 -> - z ^+ (2^n.+1) = 1 -> - primitive_root_of_unity (2^(n.+1)) z. - Proof. - intros zpow_n1 zpow1. - assert (root_of_unity (2^(n.+1)) z). - { - by apply /unity_rootP. - } - destruct (@prim_order_exists _ (2^(n.+1)) z). - - destruct (pow2_S (n.+1)). - move=> /eqP in i. - by rewrite i. - - by apply /unity_rootP. - - assert (exists (k:nat), x = expn 2 k). - { - move=> /prime.dvdn_pfactor in i0. - destruct i0. - by reflexivity. - by exists x0. - } - move: H0 i0 i => [x0 ->] i0 i. - have HH: x0 = n.+1. - { - move: i0. - rewrite div.dvdn_Pexp2l; try lia. - rewrite leq_eqVlt => /orP-[/eqP//|x0lt]. - assert (HH:expn 2 n = muln (expn 2 x0) (expn 2 (n - x0))). - { - rewrite -expnD. - f_equal. - lia. - } - by rewrite HH exprM (prim_expr_order i) Theory.expr1n in zpow_n1. - } - by rewrite HH in i. - Qed. - - Lemma two_pow_prim_root (n:nat) : - -(one T) <> (one T) -> - z ^+ (2^n) = -1 -> - primitive_root_of_unity (2^(n.+1)) z. - Proof. - intros onem1 zpowm1. - apply two_pow_prim_root_alt. - - by rewrite -zpowm1 in onem1. - - by rewrite expnSr exprM zpowm1 expr2 mulrNN mulr1. - Qed. - - Lemma unity_gcd e1 e2 : - z^+e1 = 1 -> - z^+e2 = 1 -> - z^+(gcdn e1 e2) = 1. - Proof. - intros. - destruct e2. - - by rewrite gcdn0. - - assert (0 < e2.+1) by lia. - destruct (egcdnP e1 H1). - apply (f_equal (fun y => y^+kn)) in H. - rewrite -exprM mulnC expr1n in H. - apply (f_equal (fun z => z^+km)) in H0. - by rewrite -exprM mulnC e exprD expr1n H mul1r gcdnC in H0. - Qed. - -Lemma modn_sub i j m : - i >= j -> - (i == j %[mod m]) = (i - j == 0 %[mod m]). -Proof. - move/eqn_mod_dvd->. - by rewrite mod0n. -Qed. - -Lemma modn_sub_iff i j m : - i >= j -> - i = j %[mod m] <-> i - j = 0 %[mod m]. -Proof. - move/modn_sub=>eqq. - split; move/eqP - ; [rewrite eqq | rewrite -eqq] - ; by move/eqP. -Qed. - -Lemma prim_root_inv (k n : nat) : - primitive_root_of_unity (n.+1) z -> - k < n.+1 -> - (z^+k) * (z ^+ (n.+1 - k)) = 1. -Proof. - intros. - rewrite -exprD. - replace (k + (n.+1-k))%N with (n.+1)%N by lia. - by apply prim_expr_order. -Qed. - -Lemma prim_root_inv' (k n : nat) : - n != 0%N -> - primitive_root_of_unity n z -> - k < n -> - (z^+k) * (z ^+ (n - k)) = 1. -Proof. - intros. - rewrite -exprD. - replace (k + (n-k))%N with (n)%N by lia. - by apply prim_expr_order. -Qed. - -Lemma prim_root_pow_unique (k1 k2 n : nat) : - primitive_root_of_unity n z -> - z ^+ k1 = z^+ k2 <-> k1 = k2 %[mod n]. -Proof. - intros. - generalize (eq_prim_root_expr H k1 k2); intros. - split; intros. - - move /eqP in H1. - rewrite H0 in H1. - apply (eqP H1). - - move /eqP in H1. - rewrite -H0 in H1. - apply (eqP H1). -Qed. - -Lemma two_pow_prim_root_inv (k n : nat) : - primitive_root_of_unity (2^n.+1) z -> - k < 2^n.+1 -> - (z^+k) * (z ^+ (2^n.+1 - k)) = 1. -Proof. - intros. - apply prim_root_inv'; lia. -Qed. - -Lemma prim_root_pow_sqr (k n : nat) : - n != 0%N -> - primitive_root_of_unity (2*n)%N z -> - (z^+k)^+2 = 1 -> - k = 0 %[mod n]. -Proof. - intros. - rewrite -exprM mulnC in H1. - generalize (prim_root_pow_unique (2*k) 0%N (2*n)%N H0); intros. - rewrite expr0 in H2. - assert (0 < 2*n) by lia. - rewrite H2 -muln_modr (modn_small H3) in H1. - assert (k %% n = 0)%N by lia. - by rewrite H4 mod0n. -Qed. - -Lemma zero_modn_mod2n (k n : nat) : - 0 < n -> - k = 0 %[mod n] -> - k = 0 %[mod 2*n] \/ k = n %[mod 2*n]. -Proof. - move=> npos /eqP. - move: (dvdn_eq n k). - rewrite (modn_small npos) /dvdn => -> eqq1. - - have ->: (k = k%/n * n)%N. - { - symmetry. - by apply /eqP. - } - rewrite -muln_modl. - have [-> | eqq]: (((k %/n) %% 2 = 0) \/ ((k %/n) %% 2 = 1))%N. - { - have: ((k %/ n) %% 2)%N < 2 by by apply ltn_pmod. - lia. - } - - left; by rewrite mod0n mul0n. - - right. - rewrite -{3}(mul1n n) eqq -muln_modl. - lia. -Qed. - -Lemma two_pow_prim_root_m1 (k n : nat) : - primitive_root_of_unity (2^n.+1) z -> - -(one T) <> (one T) -> - z^+k = -1 -> - k = 2^n %[mod 2^n.+1]. - Proof. - intros. - assert (2^n != 0)%N by lia. - rewrite expnS in H. - assert (z ^+ k ^+ 2 = 1). - { - by rewrite H1 expr2 mulrNN mulr1. - } - generalize (prim_root_pow_sqr k (2^n) H2 H H3); intros. - assert (k <> 0 %[mod 2^n.+1]). - { - unfold not; intros. - rewrite -expnS in H. - generalize (prim_root_pow_unique k 0 (2^n.+1) H); intros. - rewrite expr0 H1 in H6. - by rewrite -H6 in H5. - } - clear H H0 H1 H3 z T. - assert (0 < 2^n)%N by lia. - generalize (zero_modn_mod2n k (2^n) H H4); intros. - rewrite expnS. - rewrite expnS in H5. - by destruct H0. - Qed. - -Lemma two_pow_prim_root_m1_alt (n : nat) : - primitive_root_of_unity (2^n.+1) z -> - -(one T) <> (one T) -> - z^+(2^n) <> -1 -> - not (exists k, z^+k = -1). -Proof. - intros. - unfold not; intros. - destruct H2. - generalize (two_pow_prim_root_m1 x n H H0 H2); intros. - generalize (prim_expr_mod H); intros. - by rewrite -H4 H3 H4 in H2. -Qed. - - Lemma odd_pow_prim_root (n:nat) : - z ^+ (2^n) = -1 -> - forall j, - odd j -> - ((z ^+ j) ^+ (2^n)) = -1. - Proof. - intros. - by rewrite exprAC H -signr_odd H0 /= expr1. - Qed. - - Lemma gcd_odd_pow2 j n : - odd j -> - (div.gcdn j (2 ^ (S n)) = 1)%N. - Proof. - intros. - generalize (div.coprime2n j); intros. - assert (div.gcdn j 2 = 1%N). - { - rewrite H in H0. - unfold div.coprime in H0. - move => /eqP in H0. - by rewrite div.gcdnC. - } - induction n; trivial. - rewrite expnS. - assert (div.coprime j 2). - { - unfold div.coprime. - by apply /eqP. - } - now rewrite (@div.Gauss_gcdr j 2). - Qed. - - Lemma pow2_odd_inv (j n :nat) : - odd j -> - {k | (j * k mod (2^(S n)) = 1)%N}. - Proof. - intros. - assert (0 < j) by lia. - destruct (div.egcdnP (2^(S n)) H0). - exists km. - rewrite mulnC e addnC Nat.mod_add; try lia. - rewrite (gcd_odd_pow2 j n H). - apply Nat.mod_small. - destruct (pow2_S n). - move => /eqP in i0. - rewrite expnS i0; lia. - Qed. - - Lemma odd_pow_prim_root_inv (j n:nat) : - odd j -> - exists k, - forall (z2:T), - z2 ^+ (2^n) = -1 -> - ((z2 ^+ j) ^+ k) = z2. - Proof. - intros. - assert (2^(S n) <> 0)%N. - { - destruct (pow2_S (S n)). - move => /eqP in i. - rewrite i. - lia. - } - generalize (pow2_odd_inv j n H); intros. - destruct H1 as [k ?]. - exists k. - intros. - assert (z2 ^+ (2 ^ (S n)) = 1). - { - by rewrite expnS mulnC exprM H1 expr2 mulrNN mulr1. - } - rewrite -exprM. - rewrite (Nat.div_mod (j * k) _ H0) e. - by rewrite exprD exprM H2 expr1 Theory.expr1n mul1r. - Qed. - - Lemma odd_pow_prim_inv (n:nat) : - z ^+ (2^n) = -1 -> - forall j, - ((z ^+ j) ^+ (2^(S n) -1)) * (z ^+ j) = 1. - Proof. - intros. - rewrite -exprM -exprD /= -{2}(muln1 j) -mulnDr mulnC exprM. - rewrite addBnCAC; try lia. - by rewrite subnn add0n expnS mulnC exprM H expr2 mulrNN mulr1 expr1n. - Qed. - - Lemma odd_pow_prim_root_inj (j n:nat) (z2 : T) : - z ^+ (2^n) = -1 -> - z2 ^+ (2^n) = -1 -> - z <> z2 -> - odd j -> - z ^+ j <> z2 ^+ j. - Proof. - intros. - unfold not; intros. - destruct (odd_pow_prim_root_inv j n H2) as [k ?]. - apply H4 in H. - apply H4 in H0. - apply (f_equal (fun z => z ^+k)) in H3. - by rewrite H H0 in H3. - Qed. - - Lemma nth_root_eq' j k (n:nat) : n != 0%nat -> - j mod n = k mod n <-> - nth_root j n = nth_root k n. - Proof. - destruct n; [lia |]=>_. - apply nth_root_eq. - Qed. - - Lemma expn_n0 (c:nat) n : c != 0%nat -> expn c n != 0%nat. - Proof. - lia. - Qed. - - - Lemma iff_eqb (a b:bool) : (a <-> b) <-> a = b. - Proof. - destruct a; destruct b; simpl in *; firstorder. - elim H => //. - Qed. - - Lemma odd_rep (n : nat) : - odd n -> n = (2*(n%/2)+1)%N. - Proof. - intros. - generalize (divn_eq n 2); intros. - assert (n %% 2 = 1)%N. - { - by rewrite modn2 H. - } - by rewrite H1 in H0; lia. - Qed. - - - Lemma modn2_odd (x : nat) : - odd x <-> (x %% 2 = 1)%N. - Proof. - rewrite modn2. - by case: (odd _). - Qed. - - Lemma pow2_odd_rem1_odd (x n : nat) : - (x %% 2^(n.+1) = 1)%N -> odd x. - Proof. - intros. - rewrite expnS in H. - have: (x %% 2 = 1)%N. - { - assert (1 %% (2 * 2^n) = 1)%N. - { - rewrite modn_small; lia. - } - rewrite -{3}H0 in H. - move /eqP in H. - apply modn_muln in H. - move /eqP in H. - rewrite H. - rewrite modn_small; trivial. - } - by rewrite modn2_odd. - Qed. - - Lemma pow2_odd_inv_aux (j x n : nat) : - ((x * (2*j+1)) %% 2^(n.+1) = 1)%N -> - exists x0, x = (2*x0+1)%N. - Proof. - intros. - exists (x%/2)%N. - apply odd_rep. - apply pow2_odd_rem1_odd in H. - rewrite oddM in H. - move /andP in H. - tauto. - Qed. - - Lemma odd_pow_prim_roots_perm_eq (j n : nat) : - let l := mkseq (fun i => nth_root (2 * i + 1) (2 ^ (S n))) (2^n) in - perm_eq l (map (fun r => r^+(2*j+1)) l). - Proof. - assert (uniq (mkseq (fun i : nat => nth_root (2 * i + 1) (2 ^ n.+1)) (2 ^ n))). - { - apply /mkseq_uniqP. - intros ?????. - rewrite /in_mem /mem /= in H. - rewrite /in_mem /mem /= in H0. - destruct (pow2_S (S n)). - rewrite (eqP i) -nth_root_eq -(eqP i) in H1. - rewrite !Nat.mod_small in H1; try rewrite expnS; try lia. - } - assert (odd (2*j+1)) by lia. - destruct (pow2_odd_inv (2*j+1) n H0). - rewrite mulnC modulo_modn in e. - apply uniq_perm; trivial. - - rewrite map_inj_in_uniq // => a b. - rewrite /mkseq => /mapP-[i inth ->] /mapP-[k knth ->]. - do 2 rewrite pow_nth_root' ?expn_n0 //. - do 2 rewrite -nth_root_eq' ?expn_n0 //. - rewrite !modulo_modn => HH. - apply (f_equal (fun k => ((x %% 2^n.+1) * k) %% 2^n.+1)%N) in HH. - do 2 rewrite modnMm mulnA -modnMm e mul1n in HH. - by rewrite !modn_mod in HH. - - move=> a. - apply iff_eqb. - split; rewrite /mkseq -map_comp /comp => /mapP-[i inth ->]; apply/mapP. - + assert (exists x0, (2 * x0 +1)*(2 * j +1) mod 2^n.+1 = (2 * i +1)%N mod 2^n.+1). - { - destruct (pow2_odd_inv_aux _ _ _ e). - rewrite H1 in e. - exists ((2*(x0 * i)+x0+i)%N). - rewrite !modulo_modn. - replace (2 * (2 * (x0 * i) + x0 + i) + 1)%N with ((2 * x0 + 1)*(2 * i + 1))%N by lia. - rewrite -mulnA (mulnC (2 * i + 1)%N _) mulnA. - rewrite -modnMm e mul1n. - by rewrite modn_mod. - } - destruct H1. - exists (x0 mod 2^n). - * rewrite mem_iota. - apply /andP. - split; try lia. - rewrite add0n. - generalize (Nat.mod_upper_bound x0 (2^n)); lia. - * rewrite pow_nth_root'; try lia. - rewrite -nth_root_eq'; try lia. - rewrite -H1. - rewrite (mulnC (2 * x0 + 1)%N). - rewrite !modulo_modn -modnMm -(modnMm (2*j+1)%N). - do 2 f_equal. - rewrite muln_modr -expnS. - assert (1 < 2^n.+1) by lia. - generalize (modn_small H2); intros. - by rewrite -{6}H3 modnDm. - + exists ((2*(i*j)+i+j)%N mod 2^n). - * rewrite mem_iota. - apply /andP. - split; try lia. - rewrite add0n. - generalize (Nat.mod_upper_bound (2 * (i * j) + i + j) (2^n)); lia. - * rewrite pow_nth_root'; try lia. - rewrite -nth_root_eq'; try lia. - rewrite !modulo_modn muln_modr -expnS. - assert (1 < 2^n.+1) by lia. - generalize (modn_small H1); intros. - rewrite -{9}H2 modnDm. - f_equal; lia. - Qed. - - Lemma injective_finite_bijective {S} (l : list S) (f : S -> S) : - NoDup l -> - (forall s, In s l -> In (f s) l) -> - injective f -> - forall s, In s l <-> In s (map f l). - Proof. - intros. - split; intros. - - assert (NoDup (map f l)). - { - now apply FinFun.Injective_map_NoDup. - } - assert (incl l (map f l)). - { - apply NoDup_length_incl; trivial. - - now rewrite map_length. - - intros ??. - apply in_map_iff in H4. - destruct H4 as [? [??]]. - specialize (H0 x H5). - now rewrite H4 in H0. - } - now apply H4. - - apply in_map_iff in H2. - destruct H2 as [? [??]]. - rewrite -H2. - now apply H0. - Qed. - - Lemma injective_finite_permutation {S} (l : list S) (f : S -> S) : - NoDup l -> - (forall s, In s l -> In (f s) l) -> - injective f -> - Permutation l (map f l). - Proof. - intros. - apply NoDup_Permutation; trivial. - - now apply FinFun.Injective_map_NoDup. - - now apply injective_finite_bijective. - Qed. - - - - Lemma char_2_opp_eq : - one T = - (one T) <-> 2%N \in [char T]. - Proof. - unfold mem, in_mem; simpl. - rewrite mulr2n. - split;intros. - - apply (f_equal (fun (z : T) => 1 + z)) in H. - rewrite addrN in H. - by rewrite H. - - move=> /eqP in H. - apply (f_equal (fun (z : T) => z - 1)) in H. - by rewrite add0r -addrA addrN addr0 in H. - Qed. - - End unity. - - diff --git a/coq/FHE/encrypt.v b/coq/FHE/encrypt.v deleted file mode 100644 index ad7d1fa6..00000000 --- a/coq/FHE/encrypt.v +++ /dev/null @@ -1,764 +0,0 @@ -Require Import Lia List. -From mathcomp Require Import common ssreflect fintype bigop ssrnat matrix ring. -From mathcomp Require Import ssralg ssrfun ssrint seq. -From mathcomp Require Import generic_quotient ring_quotient. -From mathcomp Require Import poly mxpoly polydiv zmodp eqtype ssrbool. -From mathcomp Require Import intdiv. - -Import ssralg.GRing. -Require Import encode. - -Set Bullet Behavior "Strict Subproofs". - -Local Open Scope ring_scope. - -Lemma nat_of_ordK {p: nat} : cancel (@nat_of_ord (S (S p))) (natmul 1). -Proof. - move=> x. - by rewrite Zp_nat valZpK. -Qed. - -Lemma int_of_ordK {p: nat} : cancel (fun x:'Z_p => Posz (nat_of_ord x)) (intmul 1). -Proof. - move=> x. - by rewrite -pmulrn nat_of_ordK. -Qed. - -Section balance. - - Import ssrnum.Num.Syntax. - - (* range (-p/2, p/2] *) - Definition balanced_mod {p:nat} (x : 'Z_p):int := - if (x <= p./2)%N then x%:Z else x%:Z-p%:Z. - - (* range [-p/2, p/2) *) - Definition balanced_mod_lo {p:nat} (x : 'Z_p):int := - let xz := x %:Z in - let xzm := xz - p%:Z in - if -(p./2)%:Z <= xzm then xzm else xz. - - - Lemma absz_bound (x : int) (b : nat) : - - (b%:Z) <= x /\ x <= b%:Z <-> - (absz x <= b)%N. - Proof. - unfold absz. - split; intros. - - destruct H. - case_eq x; intros; lia. - - case_eq x; intros; rewrite H0 in H; lia. - Qed. - - Lemma absz_bound_lt (x : int) (b : nat) : - - (b%:Z) < x /\ x < b%:Z <-> - (absz x < b)%N. - Proof. - unfold absz. - split; intros. - - destruct H. - case_eq x; intros; lia. - - case_eq x; intros; rewrite H0 in H; lia. - Qed. - - Context {p : nat} {pbig:(1 < p)%nat}. - - Lemma Zp_intmul_Np (x : 'Z_p) : - x = (x%:Z - p%:Z)%:~R. - Proof. - generalize (intmul1_is_rmorphism (Zp_ringType (Zp_trunc p))); intros. - destruct H. - by rewrite base int_of_ordK -pmulrn char_Zp // oppr0 addr0. - Qed. - - Import order.Order.TotalTheory. - Import ssrnum.Num.Theory. - - Lemma balanced_mod_cong (x : 'Z_p) : - x = (balanced_mod x)%:~R. - Proof. - unfold balanced_mod. - case: (x <= p./2)%N. - - by rewrite int_of_ordK. - - by rewrite {1}(Zp_intmul_Np x). - Qed. - - Lemma balanced_mod_lo_cong (x : 'Z_p) : - x = (balanced_mod_lo x)%:~R. - Proof. - unfold balanced_mod_lo. - case: leP => _. - - by rewrite {1}(Zp_intmul_Np x). - - by rewrite int_of_ordK. - Qed. - - Lemma Zp_lt_p (x : 'Z_p): - x%:Z < p. - Proof. - generalize (ltn_ord x); intros. - by rewrite {2}Zp_cast in H. - Qed. - - Lemma Zp_lt_p_N (x : 'Z_p): - (x < p)%N. - Proof. - generalize (ltn_ord x); intros. - by rewrite {2}Zp_cast in H. - Qed. - - Lemma balanced_mod_range1 (x : 'Z_p): - balanced_mod x <= p./2. - Proof. - unfold balanced_mod. - generalize (Zp_lt_p_N x). - intros. - case leqP; lia. - Qed. - - Lemma balanced_mod_lo_range1 (x : 'Z_p): - balanced_mod_lo x <= p.-1./2. - Proof. - unfold balanced_mod_lo. - generalize (Zp_lt_p x). - case: (boolP (- (p./2)%:Z <= _)) => le1; lia. - Qed. - - Lemma balanced_mod_range2 (x : 'Z_p): - -((p.-1./2)%:Z) <= balanced_mod x. - Proof. - unfold balanced_mod. - case: leqP => HH; try lia. - Qed. - - Lemma balanced_mod_lo_range2 (x : 'Z_p): - -((p./2)%:Z) <= balanced_mod_lo x. - Proof. - unfold balanced_mod_lo. - case: (boolP (_ <= Posz x - p%:Z)) => le1; lia. - Qed. - - Lemma balanced_mod_abs_range (x : 'Z_p): - (absz (balanced_mod x) <= p./2)%N. - Proof. - apply absz_bound. - split. - - generalize (balanced_mod_range2 x); lia. - - apply balanced_mod_range1. - Qed. - - Lemma balanced_mod_lo_abs_range (x : 'Z_p): - (absz (balanced_mod_lo x) <= p./2)%N. - Proof. - apply absz_bound. - split. - - apply balanced_mod_lo_range2. - - generalize (balanced_mod_lo_range1 x); lia. - Qed. - - Lemma balanced_mod_unique (c1 c2 : int) : - c1 <= p./2 -> - c2 <= p./2 -> - -((p.-1./2)%:Z) <= c1 -> - -((p.-1./2)%:Z) <= c2 -> - ((c1 - c2) %% p)%Z = 0 -> - c1 = c2. - Proof. - intros. - case (leP (0%Z) (c1 - c2)%Z) => le0. - - assert (le0_lep:0%Z <= c1 - c2 < p%:Z). - { - apply /andP; lia. - } - generalize (modz_small le0_lep); lia. - - assert (le0_lep:0%Z <= c2 - c1 < p%:Z). - { - apply /andP; lia. - } - assert (((c2 - c1) %% p)%Z = 0). - { - replace (c2 - c1)%Z with (-1 * (c1 - c2))%Z by lia. - by rewrite -modzMmr H3 mulr0 mod0z. - } - generalize (modz_small le0_lep); lia. - Qed. - -End balance. - -Section encrypted_ops. - - Variable (q:nat). - Hypothesis (qodd : (odd q)). - - Variable (err_poly secret_poly a_poly : {poly 'Z_q}). - Variable (ρ : nat). (* bound for errs *) - - (* err_poly is small, a_poly is random over 'Z_q *) - - Definition public_key := (-a_poly * secret_poly + err_poly, a_poly). - - Definition encrypt (m : {poly 'Z_q}) : ({poly 'Z_q} * {poly 'Z_q}) := - (m + fst public_key, snd public_key). - - Definition encrypt_z (m : {poly int}) := encrypt (red_poly m q). - - Definition rounded_div (num : int) (den : nat) := - let denz := den %:Z in - let q := (num %/ denz)%Z in - let rem := num - q * denz in - if absz rem <= den./2 then q else q+1. - - Lemma add_opp [R : comRingType] (x : R) : - (-x) + x = 0. - Proof. - ring. - Qed. - - Lemma rounded_div_rem_small (num : int) (den : nat) : - (0 < den)%N -> - absz (num - (rounded_div num den) * (den%:Z))%Z <= den ./2. - Proof. - intros. - apply absz_bound. - unfold rounded_div. - case: (boolP ((`|(num - (num %/ den)%Z * den)%R|) <= _)) => HH. - - apply absz_bound in HH. - destruct HH. - split; lia. - - split; try lia. - Qed. - - Definition coef_norm {qq:nat} (p : {poly 'Z_qq}) := - list_max (map absz (map balanced_mod p)). - - Hypothesis (err_poly_small : coef_norm err_poly <= ρ). - - Definition decrypt mpair := (fst mpair) + (snd mpair) * secret_poly. - - Lemma encrypt_decrypt (m : {poly 'Z_q}) : - decrypt (encrypt m) = m + err_poly. - Proof. - unfold decrypt, encrypt, public_key, fst, snd. - ring. - Qed. - -(* - (* following already defined in ssralg *) - Definition add_pair (p1 p2 : ({poly 'Z_q} * {poly 'Z_q})) := - (fst p1 + fst p2, snd p1 + snd p2). -*) - - Definition scale_pair (m : {poly 'Z_q}) (p : ({poly 'Z_q} * {poly 'Z_q})) := - (m * fst p, m * snd p). - - Definition mul_pair (p1 p2 : ({poly 'Z_q} * {poly 'Z_q})) := - (fst p1 * fst p2, (fst p1 * snd p2 + snd p1 * fst p2, snd p1 * snd p2)). - - Lemma CKKS_add (m1 m2 : {poly 'Z_q}) : - decrypt (add_pair (encrypt m1) (encrypt m2)) = - decrypt (encrypt m1) + decrypt(encrypt m2). - Proof. - rewrite !encrypt_decrypt. - unfold add_pair, decrypt, encrypt, public_key, fst, snd. - ring. - Qed. - - Lemma CKKS_scale (m1 m2 : {poly 'Z_q}) : - decrypt (scale_pair m1 (encrypt m2)) = - m1 * decrypt(encrypt m2). - Proof. - unfold scale_pair, encrypt, decrypt, public_key, fst, snd. - ring. - Qed. - - Definition decrypt_mul trip := fst trip + secret_poly * decrypt (snd trip). - - Lemma CKKS_mul_trip (m1 m2 : {poly 'Z_q}) : - decrypt_mul (mul_pair (encrypt m1) (encrypt m2)) = - decrypt (encrypt m1) * decrypt (encrypt m2). - Proof. - unfold mul_pair, encrypt, decrypt_mul, decrypt, public_key, fst, snd. - ring. - Qed. - - Variable (p:nat). (* relin_modulus *) - Hypothesis pbig : p > q. - - Definition pq_embed (c : 'Z_q) : 'Z_(p*q) := (balanced_mod c)%:~R. - - Definition secret_p := map_poly pq_embed secret_poly. - - Variable (relin_err relin_a : {poly 'Z_(p*q)}). - Hypothesis (relin_err__small : coef_norm relin_err <= ρ). - - Definition rescale (q1 q2 : nat) (c : 'Z_(q1*q2)) : 'Z_q2 := - (rounded_div (balanced_mod c) q1)%:~R. - - Definition rescale_gen (q1 q2 : nat) (c : 'Z_q1) : 'Z_q2 := - (rounded_div ((balanced_mod c) * q2) q1)%:~R. - - Definition scale_up (q1 q2 : nat) (c : 'Z_q1) : 'Z_(q1*q2) := - inZp (muln q2 c). - - Lemma scale_up_additive (q1 q2 : nat): - additive (scale_up (Zp_trunc q1).+2 (Zp_trunc q2).+2). - Proof. - intros x y. - rewrite /scale_up /add /opp /= /inZp. - apply ord_inj => /=. - set q1' := (Zp_trunc q1).+2. - set q2' := (Zp_trunc q2).+2. - rewrite {2 4}(@Zp_cast (q1' * q2')) //. - rewrite !div.modnDmr !div.modnDml. - rewrite div.muln_modr. - unfold q1', q2'. - replace ((Zp_trunc ((Zp_trunc q1).+2 * (Zp_trunc q2).+2)).+2 ) with - ((Zp_trunc q2).+2 * (Zp_trunc q1).+2)%N at 1 by (unfold Zp_trunc; lia). - rewrite div.modn_mod. - f_equal; try lia. - rewrite div.modn_small. - - rewrite mulnDr mulnBr. - f_equal. - f_equal. - rewrite mulnC. - unfold Zp_trunc. - lia. - - rewrite mulnC. - have: (ltn y (Zp_trunc q1).+2). - + apply ltn_ord. - + assert (0 < (Zp_trunc q2).+2). - { - unfold Zp_trunc; lia. - } - by rewrite ltn_mul2r. - Qed. - - Definition rescale1 (q1 q2 : nat) (c : 'Z_(q1*q2)) : 'Z_q2 := inZp c. - - Lemma rescale1_is_rmorphism (q1 q2 : nat) : - rmorphism (rescale1 (Zp_trunc q1).+2 (Zp_trunc q2).+2). - Proof. - unfold rescale1. - generalize (intmul1_is_rmorphism (Zp_ringType (Zp_trunc q2))); intros. - destruct H as [? [? ?]]. - constructor. - - intros x y. - rewrite /add/opp/=/Zp_add /= /inZp. - apply ord_inj => /=. - set q1' := (Zp_trunc q1).+2. - set q2' := (Zp_trunc q2).+2. - rewrite {2 4}(@Zp_cast (q1' * q2')) //. - rewrite !div.modnDmr !div.modnDml div.modn_dvdm. - + suff: Posz (div.modn (x + (q1' * q2' - y)) q2') = Posz (div.modn (x + (q2' - div.modn y q2')) q2') - by inversion 1. - rewrite -!modz_nat !PoszD -!ssrint.subzn. - * rewrite -modzDmr -(modzDml (muln q1' q2')). - rewrite PoszM modzMl add0r modzDmr -!modz_nat. - rewrite -(modzDmr x (Posz q2' - _)). - rewrite -(modzDml q2' _). - rewrite modzz add0r modzNm. - by rewrite modzDmr. - * apply ltnW. - by apply div.ltn_pmod. - * apply ltnW. - apply ltn_ord. - + by rewrite div.dvdn_mull. - - constructor. - + intros x y. - rewrite -!Zp_nat !pmulrn. - rewrite -m -!pmulrn !Zp_nat /inZp. - apply ord_inj => //=. - rewrite div.modn_dvdm //. - rewrite {1}Zp_cast //. - rewrite (@Zp_cast ((Zp_trunc q1).+2 * (Zp_trunc q2).+2)) //. - by rewrite div.dvdn_mull //. - + by rewrite /= div.modn_small //. - Qed. - - Canonical rescale1_rmorphism (q1 q2: nat) := RMorphism (rescale1_is_rmorphism q1 q2). - - Lemma cdivqq_int (q1 q2 : nat) (c : int): - (0 < q2)%N -> - (c %/ q1)%Z = ((c * q2) %/ (q1 * q2)%N)%Z. - Proof. - intros. - rewrite -(@divzMpr q2%:Z); [| lia]. - do 2 f_equal. - Qed. - - Lemma lt_muln_iff (n1 n2 n3 : nat) : - (n1 < n2)%N <-> (n1 * (S n3) < n2 * (S n3))%N. - Proof. - induction n3; lia. - Qed. - - Lemma le_half_odd (n1 n2 : nat) : - odd n2 -> - (n1 <= n2./2)%N <-> (n1.*2.+1 <= n2)%N. - Proof. - lia. - Qed. - - Lemma le_half_mul_odd (n1 n2 n3 : nat) : - odd n2 -> - odd n3 -> - (n1 <= n2./2)%N <-> (n1 * n3 <= (n2 * n3)./2)%N. - Proof. - intros. - rewrite le_half_odd // le_half_odd; try lia. - replace ((n1 * n3).*2) with ((n1.*2)*n3)%N by lia. - replace n3 with (n3.-1.+1) by lia. - apply lt_muln_iff. - Qed. - - Lemma rounded_div_scale_div (q1 q2 : nat) (c : int): - odd q1 -> - odd q2 -> - rounded_div c q1 = rounded_div (c * q2) (q1 * q2). - Proof. - intros. - assert (0 < q2)%N by lia. - rewrite /rounded_div -!cdivqq_int //. - have: ((c * q2 - (c %/ q1)%Z * (q1 * q2)%N)%R = - (c - (c %/ q1)%Z * q1)%R * q2) by lia. - move ->. - rewrite abszM absz_nat. - generalize (le_half_mul_odd `|(c - (c %/ q1)%Z * q1)%R| q1 q2 H H0); intros. - case: leP; case: leP; lia. - Qed. - - Lemma rescale_gen_prop (q1 q2 : nat) (c : 'Z_(q1*q2)): - odd q1 -> - odd q2 -> - rescale q1 q2 c = rescale_gen (q1 * q2) q2 c. - Proof. - intros. - unfold rescale, rescale_gen, balanced_mod. - by rewrite -rounded_div_scale_div. - Qed. - - Definition red_p_q (c : 'Z_(p*q)) : 'Z_q := rescale p q c. - - Definition relin_V2_aux (c2 : {poly 'Z_q}) := - let b := - relin_a * secret_p + (secret_p ^+ 2)*+p + relin_err in - let cp := map_poly pq_embed c2 in - (map_poly red_p_q (cp * b), map_poly red_p_q (cp * relin_a)). - - Definition relin_V2 trip := - add_pair (fst trip, fst (snd trip)) - (relin_V2_aux (snd (snd trip))). - - Definition CKKS_mul (p1 p2 : ({poly 'Z_q} * {poly 'Z_q})) : - ({poly 'Z_q} * {poly 'Z_q}) := - relin_V2 (mul_pair p1 p2). - -End encrypted_ops. - -Section rotation. - - (* show p x -> p (x^k) is a morphism *) - Definition poly_shift [R:ringType] (k : nat) (p : {poly R}) : {poly R} - := comp_poly 'X^k p. - - Definition poly_shift_alt [R:ringType] (k : nat) (p : {poly R}) : {poly R} - := \poly_(i < (k * (seq.size p).-1).+1) (if div.dvdn k i then p`_(div.divn i k) else 0). - - Lemma poly_shift_altE [R:ringType] (k : nat) (p : {poly R}) : - poly_shift k.+1 p = poly_shift_alt k.+1 p. - Proof. - case: (@eqP _ (seq.size p) 0%nat). - - move/seq.size0nil. - rewrite -polyseq0 => /poly_inj->. - rewrite /poly_shift /poly_shift_alt comp_poly0. - apply polyP => i. - rewrite coef_poly !coef0. - case: ltP => //. - by case: div.dvdnP => //. - - move=> pn0. - apply polyP => i. - rewrite /poly_shift /poly_shift_alt. - rewrite !coef_comp_poly_Xn //= coef_poly /=. - case: div.dvdnP. - + move=>[m ->]. - rewrite div.mulnK //. - case: ltP => // nlt. - rewrite seq.nth_default //. - move /ltP: nlt. - rewrite mulnC ltnS leq_pmul2l //. - rewrite leqNgt Bool.negb_involutive. - by case: (seq.size p) pn0. - + by case: ltP. - Qed. - - Lemma poly_shift_altE' [R:ringType] (k : nat) (p : {poly R}) : k != 0%nat -> - poly_shift k p = poly_shift_alt k p. - Proof. - destruct k => // _. - apply poly_shift_altE. - Qed. - - Lemma poly_shift_1 [R:ringType] (k : nat): - @poly_shift R k 1 = 1. - Proof. - by rewrite /poly_shift comp_polyC. - Qed. - - Lemma poly_shift_is_rmorphism [R:comRingType] (k : nat) : - rmorphism (poly_shift (R := R) k). - Proof. - unfold poly_shift. - constructor. - - intros ??. - by rewrite comp_polyB. - - split. - + intros ??. - by rewrite comp_poly_multiplicative. - + by rewrite comp_polyC polyC1. - Qed. - - Lemma poly_shift_injective [R:ringType] (k:nat) : injective (poly_shift (R:=R) (S k)). - Proof. - move=> a b eqq. - apply polyP => i. - apply (f_equal (coefp (k.+1 * i))) in eqq. - move: eqq. - rewrite /poly_shift /=. - rewrite !coef_comp_poly_Xn //=. - rewrite !div.dvdn_mulr //. - by rewrite !div.mulKn //. - Qed. - - Lemma poly_shift1_id [R:ringType] (p : {poly R}): - @poly_shift R 1 p = p. - Proof. - apply polyP => i. - rewrite /poly_shift /=. - rewrite !coef_comp_poly_Xn //=. - by rewrite div.dvd1n div.divn1. - Qed. - - Lemma size_poly_shift [R:ringType] (k:nat) (p : {poly R}) (pn0:p!=0) : - seq.size (poly_shift (k.+1) p) = (k.+1 * (seq.size p).-1).+1%nat. - Proof. - rewrite poly_shift_altE. - rewrite size_poly_eq //=. - rewrite div.dvdn_mulr ?div.dvdnn //. - rewrite div.mulKn //. - by rewrite -lead_coefE lead_coef_eq0. - Qed. - - Lemma size_poly_shift' [R:ringType] (k:nat) (p : {poly R}) (pn0:p!=0) : - k != 0%nat -> - seq.size (poly_shift k p) = (k * (seq.size p).-1).+1%nat. - Proof. - elim: k => //. - move=> k _ _. - by apply size_poly_shift. - Qed. - - Definition poly_unshift [R:ringType] (k : nat) (p : {poly R}) := - \poly_(i < (div.divn (seq.size p).-1 k).+1) (p`_(k*i)). - - Lemma poly_shiftK [R:comRingType] (k: nat) : cancel (@poly_shift R (S k)) (@poly_unshift R (S k)). - Proof. - move=> p. - case: (@eqP _ p 0). - - move=> -> /=. - rewrite /poly_shift comp_poly0 /poly_unshift. - apply polyP=> i. - rewrite coef_poly !polyseq0 /= !seq.nth_nil. - by case: ltP. - - rewrite /poly_unshift => /eqP-pn0. - rewrite size_poly_shift //. - rewrite poly_shift_altE /poly_shift_alt. - apply polyP=> i. - rewrite coef_poly => /=. - rewrite div.mulKn //. - rewrite -polySpred //. - case: ltP. - + move=> ilt. - rewrite coef_poly div.mulKn //. - rewrite div.dvdn_mulr ?div.dvdnn //. - rewrite ltnS leq_pmul2l //. - rewrite polySpred // in ilt. - rewrite -ltnS. - by move/ltP: ilt => ->. - + move=> inlt. - rewrite seq.nth_default //. - rewrite leqNgt. - by apply/ltP. - Qed. - - - Lemma comp_poly_exp_polyX [R:ringType] j k : - (polyX R) ^+ (j * k) = comp_poly ('X^ j) ('X^ k). - Proof. - by rewrite comp_Xn_poly /= -exprM. - Qed. - - Lemma poly_shiftM [R:comRingType] (j k: nat) (p: {poly R}) : - poly_shift (j * k) p = poly_shift j (poly_shift k p). - Proof. - by rewrite /poly_shift -comp_polyA comp_poly_exp_polyX. - Qed. - - Lemma lin_div_odd_power [R:ringType] k : - odd k -> - Pdiv.Ring.rdvdp (R := R) ('X + 1%:P) ('X^k + 1%:P). - Proof. - rewrite -{1}(opprK 1%:P). - replace (- polyC (R:=R) 1) with (polyC (R:=R) (-1)). - - intros. - rewrite Pdiv.Ring.rdvdp_XsubCl /root hornerD hornerXn hornerC. - by rewrite -signr_odd H /= expr1 addrC addrN. - - by rewrite polyCN polyC1. - Qed. - - Lemma rdvdp_comp_poly_monic [R:comRingType] (r p q : {poly R}) : - p \is monic -> - p \Po r \is monic -> - Pdiv.Ring.rdvdp p q -> - Pdiv.Ring.rdvdp (p \Po r) (q \Po r). - Proof. - move=> monp monpr. - have [-> | pn0] := eqVneq p 0. - - by rewrite comp_poly0 !Pdiv.Ring.rdvd0p; move/eqP->; rewrite comp_poly0. - - rewrite Pdiv.ComRing.rdvdp_eq. - rewrite (monicP monp) expr1n /= scale1r. - set s := Pdiv.Ring.rdivp (R:=R) _ _; move/eqP=> Hq. - apply: (@mathcomp.algebra.polydiv.Pdiv.RingMonic.eq_rdvdp _ _ _ (s \Po r)). - + trivial. - + by rewrite -comp_polyM -{}Hq. - Qed. - - Lemma pow2_div_odd_power [R:comRingType] k n : - odd k -> - Pdiv.Ring.rdvdp (R := R) ('X^(2^n) + 1%:P) ('X^k ^+(2^n) + 1%:P). - Proof. - move=> oddk. - case: (@eqVneq _ n 0%nat). - - move=> ->. - rewrite expn0 !expr1. - by apply lin_div_odd_power. - - move=> nn0. - move: (rdvdp_comp_poly_monic (R:=R) ('X^(2 ^ n)) ('X + 1%:P) ('X^k + 1%:P)). - rewrite lin_div_odd_power //. - rewrite (Xn_add_c_monic 0). - rewrite !comp_polyD !comp_polyX !comp_polyC. - have-> : 'X^(2 ^ n) + 1%:P \is @monic R. - { - case: (@eqP _ (expn 2 n) 0%nat) =>eqq. - - lia. - - destruct (expn 2 n); [lia |]. - apply Xn_add_c_monic. - } - rewrite comp_Xn_poly -!exprM. - rewrite [muln (2^n) k]mulnC. - by apply. - Qed. - -End rotation. - - Require Import Reals nth_root encode. - From mathcomp Require Import Rstruct complex. - - Lemma poly_shift_eval [S : comRingType] (p : {poly S}) (k : nat) (v : S) : - p.[v^+k] = (poly_shift k p).[v]. - Proof. - unfold poly_shift. - by rewrite horner_comp hornerXn. - Qed. - - Lemma poly_shift_C (p : {poly R}) (k : nat) : - poly_shift k (map_poly RtoC p) = map_poly RtoC (poly_shift k p). - Proof. - by rewrite /poly_shift map_comp_poly map_polyXn. - Qed. - - Lemma poly_shift_eval_C (p : {poly R}) (k:nat) (v : R[i]) : - (map_poly RtoC p).[v^+k] = (map_poly RtoC (poly_shift k p)).[v]. - Proof. - by rewrite poly_shift_eval poly_shift_C. - Qed. - - Lemma conj_poly_eval_pow (p : {poly R}) (i j :nat) : - let v := nth_root i (S j) in - conjc ((map_poly RtoC p).[v]) = (map_poly RtoC (poly_shift j p)).[v]. - Proof. - simpl. - rewrite -poly_shift_eval_C. - rewrite rpoly_eval_conj. - f_equal. - by rewrite -conj_pow_nth_root. - Qed. - - From mathcomp Require Import ssrnat div. - Lemma poly_odd_pow_prim_roots_perm_eq (j n : nat) (p : {poly R[i]}): - let l := mkseq (fun i => nth_root (2 * i + 1) (2 ^ (S n))) (2^n) in - perm_eq (map (fun x => p.[x]) l) - (map (fun x => (poly_shift (2*j+1) p).[x]) l). - Proof. - pose (l := mkseq (fun i => nth_root (2 * i + 1) (2 ^ (S n))) (2^n)). - have /= ->: map (fun x => (poly_shift (2*j+1) p).[x]) l = - map (fun x => p.[x]) (map (fun x => x^+(2*j+1)) l). - { - rewrite -map_comp /ssrfun.comp. - apply eq_map=> a. - by rewrite poly_shift_eval. - } - apply perm_map. - apply odd_pow_prim_roots_perm_eq. - Qed. - - Lemma nth_root_pow_trans n i k : - exists j, - (nth_root (2 * i + 1) (2 ^ (S n))) ^+ (2 * j + 1) = - nth_root (2 * k + 1) (2 ^ (S n)). - Proof. - assert (exists j, - (2 * i + 1) * (2 * j + 1) = 2 * k + 1 %[mod 2^(S n)]). - { - generalize (pow2_odd_inv (2 * i + 1) n); intros. - destruct H; try lia. - generalize (pow2_odd_inv_aux i x n); intros. - rewrite (mulnC x _) in H. - rewrite modulo_modn in e. - specialize (H e). - destruct H. - rewrite H in e. - exists (2 * x0 * k + x0 + k)%N. - replace (2 * (2 * x0 * k + x0 + k) + 1)%N with - ((2 * x0 + 1) * (2 * k + 1))%N by lia. - rewrite mulnA. - assert (1 %% 2^n.+1 = 1)%N. - { - rewrite modn_small; lia. - } - rewrite -{6}H0 in e. - apply (f_equal (fun z => (z * (2 * k + 1) %% 2^ n.+1)%N)) in e. - by rewrite H0 mul1n modnMml in e. - } - destruct H. - exists x. - rewrite pow_nth_root'; try lia. - apply nth_root_eq'; try lia. - by rewrite mulnC !modulo_modn. - Qed. - - Lemma poly_odd_pow_prim_roots_perm_trans (n : nat) (p : {poly R[i]}): - forall i k, - exists j, - (poly_shift (2*j+1) p).[nth_root (2 * i + 1) (2 ^ (S n))] = - p.[nth_root (2 * k + 1) (2 ^ (S n))]. - Proof. - intros. - destruct (nth_root_pow_trans n i k). - exists x. - by rewrite -poly_shift_eval H. - Qed. - - - - - - - diff --git a/coq/FHE/nth_root.v b/coq/FHE/nth_root.v deleted file mode 100644 index a5dafc2f..00000000 --- a/coq/FHE/nth_root.v +++ /dev/null @@ -1,885 +0,0 @@ -Require Import Reals Lra Lia List. -From mathcomp Require Import complex ssreflect common eqtype ssrint ssrnat Rstruct. -Import ssralg.GRing. -Import ssralg. - -Ltac coq_lra := lra. - -From mathcomp Require Import lra. - -Set Bullet Behavior "Strict Subproofs". - -Lemma S_INR_not_0 n : - INR (S n) <> Rdefinitions.R0. -Proof. - rewrite S_INR. - generalize (pos_INR n). - coq_lra. -Qed. - -Lemma S_INR_n0 n : is_true (INR (S n) != (zero R_ringType)). -Proof. - intros. - move: (S_INR_not_0 n) => HH. - by case eqP. -Qed. - -(* represent complex number as pair *) -Definition nth_root (j n : nat) : R[i] := - let c := (2*PI*INR(j)/INR(n))%R in - ((cos c) +i* (sin c))%C. - -Local Open Scope ring_scope. -Delimit Scope complex_scope with C. -Local Open Scope complex_scope. - -Definition RtoC (x : R) := Complex x Rdefinitions.R0. -Definition C1 := Complex R1 Rdefinitions.R0. -Definition C0 := Complex Rdefinitions.R0 Rdefinitions.R0. -Definition C := ComplexField.complex_ringType R_fieldType. - -Lemma nth_root_0 n : - nth_root 0 (S n) = 1. -Proof. - unfold nth_root. - assert ((2 * PI * INR 0 / INR (S n))%R = 0%R). - { - unfold INR at 1. - by rewrite mulr0 mul0r. - } - by rewrite H /zero /= cos_0 sin_0. -Qed. - -Lemma nth_root_2PI n j : - nth_root (j * (S n)) (S n) = 1. -Proof. - unfold nth_root. - replace (2 * PI * INR (j * S n) / INR (S n))%R with - (0 + 2 * (INR j) * PI)%R. - - by rewrite cos_period sin_period /zero /= cos_0 sin_0. - - rewrite add0r mult_INR -mulrA -mulrA. - rewrite (mulrC _ PI) -(mulrA 2 _). - f_equal. - f_equal. - rewrite -mulrA divff. - + by rewrite mulr1. - + by apply S_INR_n0. -Qed. - -Lemma nth_root_2PI_plus n j k : - nth_root (j + k * (S n)) (S n) = nth_root j (S n). -Proof. - unfold nth_root. - replace (2 * PI * INR (j + k * S n) / INR (S n))%R with - (2 * PI * INR(j)/INR(S n) + 2 * INR k * PI)%R. - - now rewrite cos_period; rewrite sin_period. - - rewrite plus_INR; rewrite mult_INR. - rewrite (mulrDr (2 * PI) _ _) mulrDl. - f_equal. - rewrite -mulrA (mulrC _ PI) -mulrA -mulrA. - f_equal. - f_equal. - rewrite -mulrA divff. - + by rewrite mulr1. - + by apply S_INR_n0. -Qed. - -Definition nth_roots (n:nat) := - map (fun j => nth_root j n) (seq 0 n). - - -Lemma de_moivre (x : R) (n : nat) : - exp (cos x +i* sin x) n = (cos ((INR n) * x)%R +i* sin ((INR n) * x)%R). -Proof. - rewrite /mul /= -iter_mulr_1. - induction n. - - simpl. - rewrite Rmult_0_l. - now rewrite cos_0; rewrite sin_0. - - simpl iter. - rewrite IHn S_INR Rmult_plus_distr_r Rmult_1_l. - rewrite cos_plus sin_plus /= /mul /=. - rewrite /mul /add /opp /=. - f_equal; ring. - Qed. - -Lemma exp_nth_root j n e : - exp (nth_root j (S n)) e = nth_root (e * j) (S n). -Proof. - unfold nth_root. - rewrite de_moivre mult_INR. - assert ((INR e * (2 * PI * INR j / INR n.+1)%R)%R = 2 * PI * (INR e * INR j)%R / INR n.+1). - { - replace (S n) with (n + 1)%nat by lia. - unfold mul, inv; simpl. - field. - } - by rewrite -H. -Qed. - -Lemma exp_nth_root_comm j n e : - exp (nth_root j (S n)) e = exp (nth_root e (S n)) j. -Proof. - do 2 rewrite exp_nth_root. - f_equal. - apply mulnC. -Qed. - -Lemma nth_root_npow j n : - exp (nth_root j (S n)) (S n) = 1. -Proof. - by rewrite exp_nth_root mulnC nth_root_2PI. -Qed. - -Lemma minus_mod (j1 j2 n : nat) : - j1 mod (S n) = j2 mod (S n) -> - (j2 - j1) mod (S n) = 0%nat. -Proof. - intros eqq1. - destruct (le_dec j1 j2). - - generalize (Zdiv.Zminus_mod (Z.of_nat j2) (Z.of_nat j1) (Z.of_nat (S n))) - ; intros HH. - rewrite <- Nat2Z.inj_sub in HH by trivial. - repeat rewrite <- Nat2Z.inj_mod in HH. - rewrite -eqq1 Z.sub_diag Zdiv.Zmod_0_l in HH. - apply (f_equal Z.to_nat) in HH. - now rewrite Nat2Z.id in HH. - - unfold subn, subn_rec. - rewrite Minus.not_le_minus_0_stt; trivial. - now apply Nat.mod_0_l. -Qed. - -Lemma nth_root_mod j1 j2 n : - j1 mod (S n) = j2 mod (S n) -> - nth_root j1 (S n) = nth_root j2 (S n). -Proof. - intros. - destruct (le_dec j1 j2). - - assert (exists (k:nat), (j2 = j1 + k * (S n))%N). - { - generalize (Nat.div_mod_eq (j2 - j1) (S n)); intros. - exists ((j2 - j1)/(S n))%N. - rewrite minus_mod in H0; trivial; lia. - } - destruct H0. - rewrite H0. - now rewrite nth_root_2PI_plus. - - assert (exists (k:nat), (j1 = j2 + k * (S n))%N). - { - generalize (Nat.div_mod_eq (j1 - j2) (S n)); intros. - exists ((j1 - j2)/(S n))%N. - rewrite minus_mod in H0; lia. - } - destruct H0. - rewrite H0. - now rewrite nth_root_2PI_plus. - Qed. - -Lemma prim_nth_root j n : - nth_root j (S n) = exp (nth_root 1 (S n)) j. -Proof. - rewrite exp_nth_root. - f_equal. - lia. - Qed. - -Lemma nth_root_not_0 j n : - nth_root j (S n) <> 0. -Proof. - rewrite /nth_root. - generalize (cos_sin_0 (2 * PI * INR j / INR (S n))%R); intros. - intros ?. - apply H. - split. - - by apply (f_equal (fun c => Re c)) in H0. - - by apply (f_equal (fun c => Im c)) in H0. - Qed. - -Lemma cos1_sin0 (x : R) : - cos x = 1 -> - sin x = 0. -Proof. - intros eqq1. - generalize (cos2 x). - rewrite eqq1; intros eqq2. - rewrite Rsqr_1 in eqq2. - apply Rsqr_0_uniq. - coq_lra. -Qed. - -Section sin_cos. - Local Open Scope R_scope. - -Lemma cosneg1_sin0 (x : R) : - cos x = - 1 -> - sin x = 0. -Proof. - intros eqq1. - generalize (cos2 x). - rewrite eqq1; intros eqq2. - rewrite -Rsqr_neg Rsqr_1 in eqq2. - apply Rsqr_0_uniq. - coq_lra. -Qed. - -Lemma cos_sin0_alt (x : R) : - sin x = 0 <-> - Rsqr(cos x) = 1. -Proof. - split; intro eqq. - - generalize (sin2_cos2 x); intros. - rewrite eqq in H. - rewrite Rsqr_0 in H. - now rewrite Rplus_0_l in H. - - generalize (sin2 x); intros. - rewrite eqq in H. - rewrite Rminus_eq_0 in H. - now apply Rsqr_0_uniq in H. -Qed. - -Lemma Rsqr_1_iff (x : R) : - Rsqr x = 1 <-> - x = 1 \/ x = -1. -Proof. - generalize (Rsqr_eq x 1); intros. - rewrite Rsqr_1 in H. - split; intros. - - now apply H. - - unfold Rsqr. - destruct H0; rewrite H0; coq_lra. -Qed. - -Lemma cos_sin0 (x : R) : - sin x = 0 <-> - cos x = 1 \/ cos x = -1. -Proof. - split; intros. - - apply cos_sin0_alt in H. - now apply Rsqr_1_iff in H. - - apply Rsqr_1_iff in H. - now apply cos_sin0_alt in H. -Qed. - - -Lemma cos_eq_1_aux_pos (x : R) : - cos x = 1 -> - exists k, x = (PI * IZR(k))%R. -Proof. - intros eqq1. - generalize (cos1_sin0 _ eqq1); intros eqq2. - apply sin_eq_0_0 in eqq2. - destruct eqq2 as [k eqqk]. - exists k. - unfold mul; simpl; coq_lra. -Qed. - -Lemma cos_eq_1_aux_neg (x : R) : - cos x = - 1 -> - exists k, x = (PI * IZR(k))%R. -Proof. - intros eqq1. - generalize (cosneg1_sin0 _ eqq1); intros eqq2. - apply sin_eq_0_0 in eqq2. - destruct eqq2 as [k eqqk]. - exists k. - unfold mul; simpl; coq_lra. -Qed. - -Lemma sin_eq_0_aux (x : R) : - sin x = 0 -> - exists k, x = (PI * IZR(k))%R. -Proof. - intros. - apply cos_sin0 in H. - destruct H. - - now apply cos_eq_1_aux_pos. - - now apply cos_eq_1_aux_neg. -Qed. - -Lemma cos_eq_1_1 : - forall k:Z, - cos (IZR k * 2 * PI)%R = 1. -Proof. - intros k. - assert (forall n, cos (INR n * 2 * PI) = 1%R). { - intros n;induction n as [|n IHn]. - { change (INR 0) with Rdefinitions.R0. - rewrite !Rmult_0_l. - exact cos_0. } - rewrite S_INR !Rmult_plus_distr_r cos_plus IHn. - rewrite !Rmult_1_l cos_2PI sin_2PI Rmult_0_r Rminus_0_r. - reflexivity. - } - destruct (Z.abs_or_opp_abs k). - - replace (IZR k) with (INR (Z.to_nat k)). - { apply H. } - rewrite INR_IZR_INZ. - f_equal. - apply Z2Nat.id. - lia. - - replace (IZR k) with (- INR (Z.to_nat (- k)))%R. - + by rewrite mulNr mulNr cos_neg H. - + rewrite INR_IZR_INZ. - rewrite <-opp_IZR. f_equal. - lia. -Qed. - -Lemma cos_lt_1 (x : R) : - 0 < x -> - x < 2*PI -> - cos x < 1. -Proof. - intros. - generalize (COS_bound x); intros. - generalize PI_RGT_0; intro pi_gt. - destruct H1. - assert (cos x <> 1)%R. - { - unfold not. - intros. - generalize (cos_eq_1_aux_pos x H3); intros. - destruct H4. - rewrite H4 mulrC /mul /= in H0. - apply Rmult_lt_reg_r in H0; trivial. - rewrite H4 in H. - replace (IZR Z0) with (Rmult PI 0)%R in H by coq_lra. - rewrite /mul /= in H. - apply Rmult_lt_reg_l in H; trivial. - assert (x0 = Z.one). - { - apply lt_IZR in H. - apply lt_IZR in H0. - unfold Z.one. - lia. - } - rewrite H5 /Z.one /IZR /IPR mulr1 in H4. - rewrite H4 cos_PI /one /= in H3. - coq_lra. - } - rewrite /one /= in H3. - coq_lra. - Qed. - -Lemma cos_eq_1 (x : R) : - cos x = 1 -> - exists (k:Z), x = (2 * PI * IZR(k))%R. -Proof. - intros Hx. - assert (PI2_neq0: (2 * PI <> 0)%R). - { - generalize PI_neq0. - unfold mul, one, zero, natmul, add; simpl. - coq_lra. - } - destruct (euclidian_division x (2*PI) PI2_neq0) as (q & r & EQ & Hr & Hr'). - exists q. - rewrite <- (Rplus_0_r (_*_)). subst. - rewrite Rmult_comm. - apply Rplus_eq_compat_l. - rewrite cos_plus in Hx. - assert (H : cos (IZR q * 2 * PI)%R = 1%R) by ( apply cos_eq_1_1; now exists q). - rewrite -Rmult_assoc H /one /= Rmult_1_l sin_eq_0_1 in Hx. - - rewrite Rmult_0_l Rminus_0_r in Hx. - rewrite Rabs_right in Hr'. - + destruct Hr as [Hr | ->]; trivial. - exfalso. - generalize (cos_lt_1 r Hr Hr'); intros. - coq_lra. - + generalize PI_RGT_0; coq_lra. - - exists (Z.mul 2 q). - rewrite mult_IZR. - coq_lra. - Qed. - -Lemma cos_eq_neg1 (x : R) : - cos x = -1 -> - exists k, x = (2 * PI * IZR(k) + PI)%R. -Proof. - intros eqq1. - generalize (Rtrigo_facts.cos_pi_plus x); intros eqq2. - rewrite eqq1 in eqq2. - rewrite Ropp_involutive in eqq2. - apply cos_eq_1 in eqq2. - destruct eqq2 as [k eqq2]. - exists (Z.sub k 1)%Z. - rewrite minus_IZR. - replace (Rplus PI x) with (PI + x)%R in eqq2. - - replace (Rminus (IZR k) 1) with ((IZR k) - 1)%R. - + lra. - + unfold add, opp, one; simpl; coq_lra. - - unfold add; simpl; coq_lra. -Qed. - -Lemma cos_eq_1_nneg (x : R) : - cos x = 1 -> - 0 <= x -> - exists (k:nat), x = (2 * PI * INR(k))%R. -Proof. - intros. - generalize (cos_eq_1 x H); intros. - destruct H1. - rewrite H1 in H0. - replace (IZR Z0) with (2 * PI * 0)%R in H0. - - apply Rmult_le_reg_l in H0. - + unfold zero in H0; simpl in H0. - apply le_IZR in H0. - exists (Z.abs_nat x0). - rewrite H1. - do 2 f_equal. - destruct x0; simpl; trivial; try lia. - now rewrite INR_IPR. - + unfold mul, one, natmul, add; simpl. - generalize PI_RGT_0; coq_lra. - - by rewrite mulr0. -Qed. - -Lemma sin_cos_eq x y: - sin x = sin y /\ cos x = cos y -> - exists (k:Z), - x = (y + 2 * PI * IZR(k))%R. -Proof. - intros. - generalize (cos_minus x y); intros. - destruct H. - rewrite H H1 in H0. - generalize (sin2_cos2 y); intros. - rewrite Rplus_comm in H0. - unfold Rsqr in H2. - rewrite H2 in H0. - apply cos_eq_1 in H0. - destruct H0. - exists x0. - rewrite <- H0. - unfold add; simpl. - coq_lra. -Qed. - -Lemma Pi2_neq0 : - (2 * PI <> 0)%R. -Proof. - generalize PI_neq0. - unfold mul, one, zero, natmul, add; simpl. - coq_lra. -Qed. - -Lemma Pi2_neq0_alt : - is_true (2 * PI != 0). -Proof. - generalize Pi2_neq0. - by case eqP. -Qed. - -End sin_cos. - -Lemma nth_root_eq j k n : - j mod (S n) = k mod (S n) <-> - nth_root j (S n) = nth_root k (S n). -Proof. - split; intros. - - now apply nth_root_mod. - - unfold nth_root in H. - replace (S n) with (addn n 1) in H by lia. - inversion H; clear H. - pose (jj := (2 * PI * INR j)/ (INR (addn n 1))). - pose (kk := (2 * PI * INR k)/ (INR (addn n 1))). - generalize (sin_cos_eq jj kk); intros. - destruct H. - + split; trivial. - + unfold jj, kk in H. - clear H1 H2 jj kk. - replace (2 * PI * INR k / INR (addn n 1) + 2 * PI * IZR x)%R with - (2 * PI * (INR k / INR (addn n 1) + IZR x))%R in H by lra. - replace (2 * PI * INR j / INR (addn n 1))%R with - (2 * PI * (INR j / INR (addn n 1)))%R in H by lra. - apply (f_equal (fun r => (inv (2 * PI)) * r))%R in H. - generalize Pi2_neq0_alt; intros. - rewrite mulrDr -(mulrA _ (INR k) _) !(mulrA _ (2 * PI) _) (mulrC _ (2 * PI)) divff in H; trivial. - rewrite !mul1r in H. - apply (f_equal (fun r => r * (INR (addn n 1)))) in H. - replace (addn n 1) with (S n) in H by lia. - generalize (S_INR_n0 n); intros. - rewrite mulrDl -!mulrA !(mulrC _ (INR (S n))) divff in H; trivial. - rewrite !mulr1 mulrC !INR_IZR_INZ in H. - repeat rewrite <- mult_IZR in H. - repeat rewrite <- plus_IZR in H. - apply eq_IZR in H. - apply Nat2Z.inj. - rewrite !Nat2Z.inj_mod H. - transitivity (Z.modulo (Z.add (Z.of_nat k) (Z.mul x (Z.of_nat (S n)))) (Z.of_nat (S n))). - * by f_equal. - * by rewrite Zdiv.Z_mod_plus_full. -Qed. - -Lemma nth_root_pow_eq (n j k : nat) : - 0%N <> n -> - forall (e1 e2 : nat), - (nth_root j n) ^+ e1 = - (nth_root k n) ^+ e2 <-> - (e1 * j) mod n = (e2 * k) mod n. -Proof. - intros. - destruct n; try lia. - by rewrite !exp_nth_root -nth_root_eq. -Qed. - -Lemma nth_root_1_iff n j : - nth_root j (S n) = 1 <-> j mod (S n) = 0%N. -Proof. - rewrite <- (nth_root_0 n). - rewrite <- nth_root_eq. - replace (0 mod S n) with 0%N; try easy. - rewrite Nat.mod_small; lia. -Qed. - -Lemma nth_root_not_1 j n : - j mod (S n) <> 0%N -> - nth_root j (S n) <> 1. -Proof. - intros ??. - rewrite nth_root_1_iff in H0. - by rewrite H0 in H. -Qed. - -Lemma pow_nth_root_prim n : - exp (nth_root 1 (S n)) (S n) = 1. -Proof. - unfold nth_root. - rewrite de_moivre. - replace (INR n.+1 * (2 * PI * INR 1 / INR n.+1)%R) with (2 * PI)%R. - - by rewrite cos_2PI sin_2PI. - - rewrite [INR 1]INRE mulr1. - rewrite [INR n.+1 * _]mulrC. - rewrite -!mulrA mulVf ?mulr1//. - replace (zero (Field.zmodType R_fieldType)) with (INR 0) by trivial. - by rewrite !INRE ssrnum.Num.Theory.eqr_nat. -Qed. - -Lemma pow_nth_root j n : - exp (nth_root j (S n)) (S n) = 1. -Proof. - by rewrite prim_nth_root -exprM mulnC exprM pow_nth_root_prim expr1n. -Qed. - -Lemma nth_root_mul j k n : - mul (nth_root j (S n)) (nth_root k (S n)) = nth_root (j + k) (S n). -Proof. - intros. - rewrite (prim_nth_root k _). - rewrite (prim_nth_root j _). - rewrite (prim_nth_root (j + k) _). - now rewrite <- exprD. - Qed. - -Lemma nth_root_Sn n : - nth_root (S n) (S n) = 1. -Proof. - by rewrite prim_nth_root nth_root_npow. -Qed. - -Lemma Cinv_r (x : R[i]) : - x <> 0 -> - x * (inv x) = 1. -Proof. - intros. - rewrite divff //. - by case eqP. -Qed. - -Lemma Cinv_l (x : R[i]) : - x <> 0 -> - (inv x) * x = 1. -Proof. - intros. - rewrite mulrC Cinv_r //. -Qed. - -Lemma exp_sub_r (c : R[i]) (n m : nat): - (le m n) -> - c <> 0 -> - exp c (n - m) = (exp c n) / (exp c m). -Proof. - intros. - destruct H. - - rewrite subnn expr0 Cinv_r//. - generalize (@expf_neq0 _ c m). - case: eqP => //. - case: eqP => //. - intuition. - - rewrite expfB//. - case: leP => //. - lia. -Qed. - -Lemma nth_root_diff j k n : - (le j k) -> - (nth_root k (S n)) / (nth_root j (S n)) = nth_root (k-j) (S n). -Proof. - intros. - rewrite (prim_nth_root k _). - rewrite (prim_nth_root j _). - rewrite (prim_nth_root (k-j) _). - rewrite exp_sub_r; trivial. - apply nth_root_not_0. -Qed. - -Lemma nth_root_inv j n : - inv (nth_root j (S n)) = nth_root (S n - (j mod S n)) (S n). -Proof. - generalize (nth_root_diff (j mod S n) (S n) n); intros. - rewrite <- H. - - rewrite nth_root_Sn mul1r. - f_equal. - apply (nth_root_mod j (j mod S n) n). - rewrite Nat.mod_mod; try lia. - - assert (S n <> 0%N) by lia. - generalize (Nat.mod_upper_bound j (S n) H0); lia. - Qed. - -Lemma nth_root_div j k n : - (nth_root j (S n)) / (nth_root k (S n)) = - nth_root (j + (S n - (k mod S n))) (S n). -Proof. - rewrite nth_root_inv. - apply nth_root_mul. -Qed. - -Definition Cmod (x : R[i]) := (* ComplexField.Normc.normc. *) - let: a +i* b := x in sqrt (exp a 2 + exp b 2). - -Lemma nth_root_Cmod j n : - Cmod (nth_root j (S n)) = 1%R. -Proof. - unfold Cmod, nth_root, fst, snd. - rewrite Rplus_comm /one /= -sqrt_1. - f_equal. - by rewrite -!RpowE -!Rsqr_pow2 sin2_cos2. -Qed. - -Lemma Cmod_Cconj_alt (c : R[i]) : - let: a +i* b :=c in - c * (conjc c) = (a^+2 + b^+2) +i* 0. -Proof. - destruct c. - unfold mul; simpl. - f_equal; lra. -Qed. - -Lemma Cmod_Cconj (c : R[i]) : - c * (conjc c) = RtoC (Rsqr (Cmod c)). -Proof. - generalize (Cmod_Cconj_alt c); intros. - unfold Cmod, fst, snd, RtoC. - unfold RtoC in H. - destruct c. - rewrite H. - f_equal. - rewrite -!RpowE Rsqr_sqrt //. - apply Rplus_le_le_0_compat; apply pow2_ge_0. -Qed. - -Lemma nth_root_conj j n : - conjc (nth_root j (S n)) = inv (nth_root j (S n)). -Proof. - generalize (Cmod_Cconj (nth_root j (S n))); intros. - rewrite nth_root_Cmod Rsqr_1 in H. - apply (f_equal (fun c => mul (inv (nth_root j (S n))) c)) in H. - rewrite /RtoC mulr1 mulrA Cinv_l in H. - - now rewrite mul1r in H. - - by apply nth_root_not_0. -Qed. - -Lemma nth_root_conj_alt j n : - conjc (nth_root j (S n)) = nth_root (S n - j mod (S n)) (S n). -Proof. - by rewrite nth_root_conj nth_root_inv. -Qed. - -Lemma nth_root_half_pow_aux n : - exp (nth_root (S n) (2 * (S n))) 2 = 1. -Proof. - replace (muln 2 (S n)) with (S (2 * n + 1)) by lia. - rewrite exp_nth_root. - do 2 replace (muln 2 (S n)) with (S (2 * n + 1)) by lia. - now rewrite nth_root_Sn. -Qed. - -Lemma pow2_inv x y : (x ^+ 2)%R = y -> Rabs x = sqrt y. -Proof. - intros eqq1. - apply (f_equal sqrt) in eqq1. - rewrite expr2 in eqq1. - rewrite -eqq1 -sqrt_Rsqr_abs //. -Qed. - -Lemma Rabs_pm_l x y : Rabs x = y -> x = y \/ (- x)%R = y. -Proof. - unfold Rabs. - destruct (Rcase_abs); [right|left]; rewrite -H //. -Qed. - -Lemma Rabs_pm_r x y : Rabs x = y -> x = y \/ x = (- y)%R. -Proof. - unfold Rabs. - destruct (Rcase_abs); [right|left]; rewrite -H //. - unfold opp; simpl. - by rewrite Ropp_involutive. -Qed. - -Lemma cmult_real (c : R[i]) : - Im (c * c) = 0 <-> - Re c = 0 \/ Im c = 0. -Proof. - destruct c. - simpl. - split; intros. - - assert (Re * Im = 0) by lra. - rewrite /mul /zero /= in H0. - apply Rmult_integral in H0. - by rewrite /zero /=. - - destruct H; rewrite H; lra. -Qed. - -Lemma Cpow_2 (c : R[i]) : - exp c 2 = 1 -> c = 1 \/ c = - 1. -Proof. - rewrite expr2; intros. - assert (Im (c * c) = 0). - { - by rewrite H /=. - } - rewrite cmult_real in H0. - destruct c. - simpl in H0. - destruct H0. - - rewrite H0 /mul /= !mul0r mulr0 !add0r in H. - injection H; intros. - rewrite /mul /one /opp /= in H1. - generalize (pow2_ge_0 Im); intros. - rewrite /pow Rmult_1_r in H2. - coq_lra. - - rewrite H0 /mul /= !mul0r !mulr0 oppr0 !addr0 in H. - rewrite H0. - injection H; intros. - generalize (Rsqr_1_iff Re); intros. - unfold Rsqr in H2. - rewrite /one /mul /= in H1. - rewrite H2 in H1. - destruct H1. - + left. - by rewrite H1 /one /=. - + right. - by rewrite H1 /one /= /opp /= oppr0 /opp /one /= -IZR_NEG. -Qed. - -Lemma nth_root_half_pow n : - nth_root (S n) (2 * (S n)) = -1. -Proof. - generalize (nth_root_half_pow_aux n); intros. - apply Cpow_2 in H. - destruct H; trivial. - replace (muln 2 (S n)) with (S (2 * n + 1)) in H by lia. - generalize (nth_root_not_1 (S n) (2*n+1)); intros. - assert (S n mod S (2 * n + 1) <> 0%N). - { - rewrite Nat.mod_small; lia. - } - tauto. -Qed. - -Lemma pow2_S (j:nat) : - exists (k : nat), expn 2 j = S k. -Proof. - exists (2^j-1)%nat. - lia. -Qed. - -Lemma odd_roots_prim j n : - exp (nth_root (2 * j + 1) (2 ^ (S n))) (2^n) = -1. -Proof. - generalize (pow2_S (S n)); intros. - destruct H. - rewrite H. - rewrite exp_nth_root. - rewrite <- H. - assert ((2 ^ n * (2 * j + 1) mod (2 ^ S n)) = - (2 ^ n mod (2 ^ S n)))%N. - { - replace (2 ^n * (2 * j + 1))%N with (2 ^ n + j*(2 * 2^n))%N by lia. - replace (2 ^ (S n))%N with (2 * 2^n)%N. - - rewrite Nat.mod_add; try lia. - - by rewrite expnS. - } - rewrite H in H0. - apply nth_root_mod in H0. - rewrite <- H in H0. - rewrite H0. - generalize (pow2_S n); intros. - destruct H1. - simpl. - replace (2 ^ n + (2 ^n + 0))%N with (2 * 2^n)%N by lia. - rewrite expnS H1. - apply nth_root_half_pow. -Qed. - -Lemma mult_conj_root j n : - (nth_root j (S n)) * (conjc (nth_root j (S n))) = 1. -Proof. - rewrite nth_root_conj Cinv_r //. - by apply nth_root_not_0. -Qed. - -Lemma nth_root_half n : - nth_root (2 ^n) (2 ^ (S n)) = - 1. -Proof. - destruct (pow2_S (S n)). - generalize (odd_roots_prim 0 n); intros. - rewrite H exp_nth_root -H in H0. - by rewrite muln0 add0n muln1 in H0. -Qed. - -Lemma nth_root_opp j n : - (nth_root j (2 ^ (S n))) + (nth_root (j + 2^n) (2 ^ (S n))) = 0. -Proof. - destruct (pow2_S (S n)). - by rewrite H -nth_root_mul -H nth_root_half mulrN1 addrC addNr. -Qed. - -Definition Nat2Zinj := (Nat2Z.inj_mod, Nat2Z.inj_mul, Nat2Z.inj_add, Nat2Z.inj_sub). - -Lemma inv_pow_nth_root j k : - exp (nth_root j (S k)) k = inv (nth_root j (S k)). -Proof. - rewrite nth_root_inv exp_nth_root -nth_root_eq. - apply Nat2Z.inj. - rewrite !Nat2Zinj. - - rewrite Zdiv.Zmult_mod Zdiv.Zminus_mod. - rewrite Zdiv.Z_mod_same_full Zdiv.Zmod_mod. - rewrite Z.sub_0_l Z.opp_eq_mul_m1 Z.mul_comm. - symmetry. - rewrite Zdiv.Zmult_mod !Zdiv.Zmod_mod. - f_equal. - f_equal. - replace (Zneg xH)%Z with (Z.sub (Z.of_nat k) (Z.of_nat (S k))). - + rewrite Zdiv.Zminus_mod Zdiv.Z_mod_same_full. - f_equal. - rewrite Z.mod_small; lia. - + rewrite Nat2Z.inj_succ; lia. - - generalize (Nat.mod_upper_bound j (S k)); lia. -Qed. - -Lemma conj_pow_nth_root j k : - exp (nth_root j (S k)) k = conjc (nth_root j (S k)). -Proof. - by rewrite nth_root_conj inv_pow_nth_root. -Qed. - - -(* testing notations *) -Definition C0': R[i] := 0. -Definition C1': R[i] := 1. -Definition Cplus' (x y : R[i]) := x + y. -Definition Cmult' (x y : R[i]) := x * y. -Definition Cexp' (x : R[i]) (n : nat) := x ^+ n. -Definition Cdiv' (x y : R[i]) := x / y. -Definition Cinv' (x : R[i]) := x^-1. - diff --git a/coq/FHE/polyinterp.v b/coq/FHE/polyinterp.v deleted file mode 100644 index 620b5ec2..00000000 --- a/coq/FHE/polyinterp.v +++ /dev/null @@ -1,3325 +0,0 @@ -Require Import Reals Permutation Morphisms. -Require Import Coquelicot.Complex. -Require Import List. -Require Import Lra Lia. -Require Import Utils. -Require Import Vector. - -Set Bullet Behavior "Strict Subproofs". - -Lemma Forall2_nth_error_iff {A B} (P:A->B->Prop) (l1 : list A) (l2: list B) : - (forall (i : nat), match nth_error l1 i, nth_error l2 i with - | Some a, Some b => P a b - | None, None => True - | _, _ => False - end - ) <-> - Forall2 P l1 l2. -Proof. - split. - - revert l2; induction l1; destruct l2; simpl in *; trivial; intros HH. - + specialize (HH 0); simpl in HH; contradiction. - + specialize (HH 0); simpl in HH; contradiction. - + constructor. - * now specialize (HH 0); simpl in HH. - * apply IHl1; intros i. - now specialize (HH (S i)). - - induction 1; intros [| i]; simpl; trivial. - apply IHForall2. -Qed. - -Lemma nth_error_eqs {A} (l1 l2 : list A) : - (forall (i : nat), nth_error l1 i = nth_error l2 i) -> - l1 = l2. -Proof. - intros HH. - apply Forall2_eq. - apply Forall2_nth_error_iff; intros i. - rewrite HH. - now destruct (nth_error l2 i). -Qed. - -Lemma nth_error_eqs_len {A} (l1 l2 : list A) : - length l1 = length l2 -> - (forall (i : nat), i < length l1 -> nth_error l1 i = nth_error l2 i) -> - l1 = l2. -Proof. - intros eqq1 HH. - apply nth_error_eqs; intros. - destruct (lt_dec i (length l1)); auto 2. - destruct (nth_error_None l2 i) as [_ eqq2]. - rewrite eqq2 by lia. - apply nth_error_None; lia. -Qed. - -Lemma rev_nth_error {A} (l:list A) (n:nat) : - n < length l -> nth_error (rev l) n = nth_error l (length l - S n). -Proof. - revert n. - induction l using rev_ind; simpl; [lia |]; intros n nlt. - rewrite rev_app_distr, app_length; simpl. - destruct n. - - simpl. - rewrite nth_error_app2 by lia. - now replace (length l + 1 - 1 - length l) with 0 by lia. - - simpl. - rewrite app_length in nlt; simpl in nlt. - rewrite IHl by lia. - rewrite nth_error_app1 by lia. - f_equal. - lia. -Qed. - -Lemma seq_nth_error [len : nat] (start : nat) [n : nat] : - n < len -> nth_error (seq start len) n = Some (start + n). -Proof. - intros nlt. - rewrite (nth_error_nth' _ 0). - - now rewrite seq_nth. - - now rewrite seq_length. -Qed. - -Lemma rev_seq start n : - rev (seq start n) = map (fun i => (n + 2 * start - S i)) (seq start n). -Proof. - apply nth_error_eqs_len. - - now rewrite rev_length, map_length. - - intros i ilt. - rewrite rev_length in ilt. - rewrite rev_nth_error by trivial. - rewrite seq_length in ilt. - rewrite nth_error_map. - repeat rewrite seq_nth_error; trivial - ; rewrite seq_length. - + simpl; f_equal; lia. - + lia. -Qed. - -Lemma map_skipn_S_error {A:Type} (l:list A) n a : - nth_error l n = Some a -> - skipn n l = a :: skipn (S n) l. -Proof. - revert l. - induction n; simpl; destruct l; simpl; intros; try congruence. - rewrite IHn; trivial. -Qed. - -Lemma list_cons_app_hyp_even {A} (P:list A->Prop) (R:A->A->Prop) - (Pnil : P nil) - (Psmoosh : forall a z l, R a z -> R z a -> P l -> P (a::(l ++ z ::nil))) - : forall n (l:list A), length l = 2 * n -> Forall2 R l (rev l) -> P l. -Proof. - induction n; simpl. - - intros [|]; simpl; congruence. - - intros. - destruct l; simpl in *; [lia |]. - destruct l using rev_ind; simpl in *; [lia |]. - rewrite app_length in H; simpl in H. - assert (eqq1:length l = 2 * n) by lia. - specialize (IHn _ eqq1). - rewrite rev_app_distr in H0. - simpl in H0. - invcs H0. - apply Forall2_app_tail_inv in H6. - destruct H6. - auto. -Qed. - -Lemma list_cons_app_hyp_odd {A} (P:list A->Prop) (R:A->A->Prop) - (Psingle : forall a, R a a -> P (a :: nil)) - (Psmoosh : forall a z l, R a z -> R z a -> P l -> P (a::(l ++ z ::nil))) - : forall n (l:list A), length l = 2 * n + 1 -> Forall2 R l (rev l) -> P l. -Proof. - induction n; simpl. - - intros [|]; simpl; try lia. - destruct l; simpl; try lia; intros. - invcs H0. - auto. - - intros. - destruct l; simpl in *; [lia |]. - destruct l using rev_ind; simpl in *; [lia |]. - rewrite app_length in H; simpl in H. - assert (eqq1:length l = 2 * n + 1) by lia. - specialize (IHn _ eqq1). - rewrite rev_app_distr in H0. - simpl in H0. - invcs H0. - apply Forall2_app_tail_inv in H6. - destruct H6. - auto. -Qed. - -Lemma list_cons_app_hyp {A} (P:list A->Prop) (R:A->A->Prop) - (Pnil : P nil) - (Psingle : forall a, R a a -> P (a :: nil)) - (Psmoosh : forall a z l, R a z -> R z a -> P l -> P (a::(l ++ z ::nil))) - : forall (l:list A), Forall2 R l (rev l) -> P l. -Proof. - intros. - destruct (NPeano.Nat.Even_Odd_dec (length l)). - - destruct e. - eapply list_cons_app_hyp_even; eauto. - - destruct o. - eapply list_cons_app_hyp_odd; eauto. -Qed. - -Lemma nth_error_firstn_in {A} (l:list A) n i : - i < n -> - nth_error (firstn n l) i = nth_error l i. -Proof. - revert n l. - induction i; simpl; intros - ; destruct n; destruct l; simpl; trivial; try lia. - apply IHi; lia. -Qed. - -Lemma nth_error_firstn_nin {A} (l:list A) n i : - i >= n -> - nth_error (firstn n l) i = None. -Proof. - intros. - apply nth_error_None. - rewrite firstn_length. - lia. -Qed. - -Lemma nth_error_skipn {A} (l:list A) n i : - nth_error (skipn n l) i = nth_error l (n + i). -Proof. - revert l i. - induction n; simpl; trivial; intros. - destruct l; trivial. - now rewrite nth_error_nil. -Qed. - -Lemma firstn_seq n1 n2 start : - n1 <= n2 -> - firstn n1 (seq start n2) = seq start n1. -Proof. - intros. - apply nth_error_eqs_len. - - rewrite firstn_length; repeat rewrite seq_length. - lia. - - intros i ilt. - rewrite firstn_length, seq_length in ilt. - rewrite nth_error_firstn_in by lia. - now repeat rewrite seq_nth_error by lia. -Qed. - -Lemma skipn_seq n1 n2 start : - skipn n1 (seq start n2) = seq (start+n1) (n2-n1). -Proof. - intros. - apply nth_error_eqs_len. - - now rewrite skipn_length; repeat rewrite seq_length. - - intros i ilt. - rewrite skipn_length in ilt; repeat rewrite seq_length in ilt. - rewrite nth_error_skipn. - repeat rewrite seq_nth_error by lia. - f_equal; lia. -Qed. - -Lemma combine_nth_error [A B : Type] (l : list A) (l' : list B) (n : nat) : - length l = length l' -> nth_error (combine l l') n = match nth_error l n, nth_error l' n with - | Some x, Some y => Some (x,y) - | _, _ => None - end. -Proof. - revert l l'. - induction n; destruct l; destruct l'; simpl; try congruence. - intros. - apply IHn; congruence. -Qed. - - -(* represent complex number as pair *) -Definition nth_root (j n : nat) : C := - let c := (2*PI*INR(j)/INR(n))%R in - (cos c, sin c). - -Lemma S_INR_not_0 n : - INR (S n) <> 0%R. -Proof. - rewrite S_INR. - generalize (pos_INR n). - lra. -Qed. - - -Lemma nth_root_0 n : - nth_root 0 (S n) = R1. -Proof. - unfold nth_root. - assert ((2 * PI * INR 0 / INR (S n))%R = 0%R). - { - unfold INR at 1. - field. - apply S_INR_not_0. - } - rewrite H. - now rewrite cos_0, sin_0. -Qed. - -Lemma nth_root_2PI n j : - nth_root (j * (S n)) (S n) = R1. -Proof. - unfold nth_root. - rewrite mult_INR. - replace (2 * PI * (INR j * INR (S n)) / INR (S n))%R with - (0 + 2 * INR(j) * PI)%R. - - rewrite cos_period, sin_period. - now rewrite cos_0, sin_0. - - field. - apply S_INR_not_0. -Qed. - - -Lemma nth_root_2PI_plus n j k : - nth_root (j + k * (S n)) (S n) = nth_root j (S n). -Proof. - unfold nth_root. - replace (2 * PI * INR (j + k * S n) / INR (S n))%R with - (2 * PI * INR(j)/INR(S n) + 2 * INR k * PI)%R. - - now rewrite cos_period, sin_period. - - rewrite plus_INR, mult_INR. - field. - apply S_INR_not_0. - Qed. - -Definition nth_roots (n:nat) := - map (fun j => nth_root j n) (seq 0 n). - -Lemma de_moive (x : R) (n : nat) : - Cpow (cos x, sin x) n = (cos ((INR n) * x), sin ((INR n) * x)). -Proof. - induction n. - - simpl. - rewrite Rmult_0_l. - now rewrite cos_0, sin_0. - - simpl Cpow. - rewrite IHn. - unfold Cmult, fst, snd. - replace (INR (S n) * x)%R with (x + (INR n) * x)%R. - + rewrite cos_plus, sin_plus. - f_equal. - lra. - + rewrite S_INR. - lra. - Qed. - -Lemma Cpow_nth_root j n e : - Cpow (nth_root j (S n)) e = nth_root (e * j) (S n). -Proof. - unfold nth_root. - rewrite de_moive. - rewrite mult_INR. - do 2 f_equal; field; apply S_INR_not_0. -Qed. - -Lemma Cpow_nth_root_comm j n e : - Cpow (nth_root j (S n)) e = Cpow (nth_root e (S n)) j. -Proof. - do 2 rewrite Cpow_nth_root. - f_equal. - lia. -Qed. - -Lemma nth_root_npow j n : - Cpow (nth_root j (S n)) (S n) = RtoC R1. -Proof. - rewrite Cpow_nth_root. - replace (S n * j) with (j * S n) by lia. - now rewrite nth_root_2PI. -Qed. - -Lemma minus_mod (j1 j2 n : nat) : - j1 mod (S n) = j2 mod (S n) -> - (j2 - j1) mod (S n) = 0. -Proof. - intros eqq1. - destruct (le_dec j1 j2). - - generalize (Zdiv.Zminus_mod (Z.of_nat j2) (Z.of_nat j1) (Z.of_nat (S n))) - ; intros HH. - rewrite <- Nat2Z.inj_sub in HH by trivial. - repeat rewrite <- Nat2Z.inj_mod in HH. - rewrite <- eqq1 in HH. - rewrite Z.sub_diag in HH. - rewrite Zdiv.Zmod_0_l in HH. - apply (f_equal Z.to_nat) in HH. - now rewrite Nat2Z.id in HH. - - rewrite Minus.not_le_minus_0_stt by trivial. - now apply Nat.mod_0_l. -Qed. - -Lemma nth_root_mod j1 j2 n : - j1 mod (S n) = j2 mod (S n) -> - nth_root j1 (S n) = nth_root j2 (S n). -Proof. - intros. - destruct (le_dec j1 j2). - - assert (exists (k:nat), j2 = j1 + k * (S n)). - { - generalize (Nat.div_mod_eq (j2 - j1) (S n)); intros. - exists ((j2 - j1)/(S n)). - rewrite minus_mod in H0; trivial; lia. - } - destruct H0. - rewrite H0. - now rewrite nth_root_2PI_plus. - - assert (exists (k:nat), j1 = j2 + k * (S n)). - { - generalize (Nat.div_mod_eq (j1 - j2) (S n)); intros. - exists ((j1 - j2)/(S n)). - rewrite minus_mod in H0; trivial; lia. - } - destruct H0. - rewrite H0. - now rewrite nth_root_2PI_plus. - Qed. - -Fixpoint list_Cplus (l : list C) : C := - match l with - | nil => R0 - | a :: l' => Cplus a (list_Cplus l') - end. - -Lemma list_Cplus_mult_l c l : - (list_Cplus (map (fun a => c * a) l) = c * list_Cplus l)%C. -Proof. - induction l; simpl. - - ring. - - rewrite IHl. - ring. -Qed. - -Lemma list_cplus_mult_seq_comm c (p : list R) : -list_Cplus - (map (fun '(a0, b) => (RtoC a0 * b)%C) - (combine p (map (fun j : nat => (c ^ j)%C) (seq 1 (length p))))) = - (c * - list_Cplus - (map (fun '(a0, b) => RtoC a0 * b) - (combine p (map (fun j : nat => c ^ j) (seq 0 (length p))))))%C. -Proof. - rewrite <- list_Cplus_mult_l. - rewrite <- seq_shift. - rewrite map_map. - rewrite <- (map_id p) at 1 3. - repeat rewrite combine_map. - repeat rewrite map_map. - f_equal. - apply map_ext; intros [??]; simpl. - ring. -Qed. - -Lemma list_Cplus_Re (l : list C) : - Re (list_Cplus l) = list_sum (map Re l). -Proof. - induction l. - - now simpl. - - simpl. - now rewrite <- IHl. -Qed. - -Lemma prim_nth_root j n : - nth_root j (S n) = Cpow (nth_root 1 (S n)) j. -Proof. - rewrite Cpow_nth_root. - f_equal. - lia. - Qed. - -Lemma nth_root_not_0 j n : - nth_root j (S n) <> R0. -Proof. - unfold nth_root. - unfold RtoC. - generalize cos_sin_0; intros. - specialize (H (2 * PI * INR j / INR (S n))%R). - replace R0 with 0%R by lra. - unfold not. - intros. - apply H. - split. - - apply (f_equal (fun c => fst c)) in H0. - now unfold fst in H0. - - apply (f_equal (fun c => snd c)) in H0. - now unfold snd in H0. - Qed. - - -Lemma cos1_sin0 (x : R) : - cos x = R1 -> - sin x = R0. -Proof. - intros eqq1. - generalize (cos2 x). - rewrite eqq1; intros eqq2. - replace R1 with 1%R in eqq2 by trivial. - rewrite Rsqr_1 in eqq2. - apply Rsqr_0_uniq. - lra. -Qed. - -Lemma cosneg1_sin0 (x : R) : - cos x = Ropp R1 -> - sin x = R0. -Proof. - intros eqq1. - generalize (cos2 x). - rewrite eqq1; intros eqq2. - replace R1 with 1%R in eqq2 by trivial. - rewrite <- Rsqr_neg in eqq2. - rewrite Rsqr_1 in eqq2. - apply Rsqr_0_uniq. - lra. -Qed. - -Lemma cos_eq_1_aux_pos (x : R) : - cos x = R1 -> - exists k, x = (PI * IZR(k))%R. -Proof. - intros eqq1. - generalize (cos1_sin0 _ eqq1); intros eqq2. - apply sin_eq_0_0 in eqq2. - destruct eqq2 as [k eqqk]. - exists k. - lra. -Qed. - -Lemma cos_eq_1_aux_neg (x : R) : - cos x = Ropp R1 -> - exists k, x = (PI * IZR(k))%R. -Proof. - intros eqq1. - generalize (cosneg1_sin0 _ eqq1); intros eqq2. - apply sin_eq_0_0 in eqq2. - destruct eqq2 as [k eqqk]. - exists k. - lra. -Qed. - -Lemma cos_eq_1 (x : R) : - cos x = R1 -> - exists k, x = (2 * PI * IZR(k))%R. -Proof. - intros eqq1. - destruct (cos_eq_1_aux_pos _ eqq1) as [kk eqq2]; subst. - assert (cutter:(forall kk, ((0 <= kk)%Z -> cos (PI * IZR kk) = R1 -> exists k : Z, (PI * IZR kk)%R = (2 * PI * IZR k)%R)) -> (forall kk, (cos (PI * IZR kk) = R1 -> (exists k : Z, (PI * IZR kk)%R = (2 * PI * IZR k)%R - )))). - { - clear. - intros HH kk eqq1. - destruct (Z_le_gt_dec 0 kk); [eauto |]. - destruct (HH (Z.opp kk)%Z). - - lia. - - rewrite opp_IZR. - replace (PI * - IZR kk)%R with (- (PI * IZR kk))%R by lra. - now rewrite cos_neg. - - exists (Z.opp x). - rewrite opp_IZR in H |- *. - lra. - } - - apply cutter; trivial; clear. - intros kk kk_nneg eqq1. - destruct (Zeven_odd_dec kk). - - destruct (Zeven_ex _ z); subst. - exists x. - rewrite mult_IZR. - lra. - - destruct (Zodd_ex _ z); subst. - rewrite plus_IZR, mult_IZR in eqq1. - replace ((PI * (2 * IZR x + 1))%R) with - (2 * IZR x * PI + PI)%R in eqq1 by lra. - rewrite neg_cos in eqq1. - assert (eqq2: cos (2 * IZR x * PI)%R = Ropp R1) by lra. - generalize (cos_period 0 (Z.to_nat x)); intros HH. - rewrite cos_0 in HH. - rewrite INR_IZR_INZ in HH. - rewrite Z2Nat.id in HH by lia. - replace (2 * IZR x * PI)%R with (0 + 2 * IZR x * PI)%R in eqq2 by lra. - lra. -Qed. - -Lemma cos_eq_neg1 (x : R) : - cos x = Ropp R1 -> - exists k, x = (2 * PI * IZR(k) + PI)%R. -Proof. - intros eqq1. - generalize (Rtrigo_facts.cos_pi_plus x); intros eqq2. - rewrite eqq1 in eqq2. - rewrite Ropp_involutive in eqq2. - apply cos_eq_1 in eqq2. - destruct eqq2 as [k eqq2]. - exists (k-1)%Z. - rewrite minus_IZR. - lra. -Qed. - -Lemma cos_eq_1_1 : forall x:R, (exists k : Z, x = (IZR k * 2 * PI)%R) -> cos x = 1%R. -Proof. - assert (forall n, cos (INR n * 2 * PI) = 1%R). { - intros n;induction n as [|n IHn]. - { change (INR 0) with 0%R. - replace (0 * 2 * PI)%R with 0%R by ring. - exact cos_0. } - rewrite S_INR. - replace ((INR n + 1) * 2 * PI)%R with ((INR n) * 2 * PI + 2 * PI)%R by ring. - rewrite cos_plus, IHn, cos_2PI, sin_2PI. - ring. - } - intros x [k Hx]. - rewrite Hx;clear x Hx. - destruct (Z.abs_or_opp_abs k). - - replace (IZR k) with (INR (Z.to_nat k)). - { apply H. } - rewrite INR_IZR_INZ. - f_equal. - apply Z2Nat.id. - lia. - - replace (IZR k) with (- INR (Z.to_nat (- k)))%R. - { replace (- INR (Z.to_nat (- k)) * 2 * PI)%R with (- (INR (Z.to_nat (- k)) * 2 * PI))%R by ring. - rewrite cos_neg. - rewrite H;ring. } - rewrite INR_IZR_INZ. - rewrite <-opp_IZR. f_equal. - lia. -Qed. - -Lemma cos_lt_1 (x : R) : - (0 < x)%R -> - (x < 2*PI)%R -> - (cos x < 1)%R. -Proof. - intros. - generalize (COS_bound x); intros. - generalize PI_RGT_0; intro pi_gt. - destruct H1. - assert (cos x <> 1)%R. - { - unfold not. - intros. - generalize (cos_eq_1_aux_pos x H3); intros. - destruct H4. - rewrite H4 in H0. - rewrite Rmult_comm in H0. - apply Rmult_lt_reg_r in H0; trivial. - rewrite H4 in H. - replace 0%R with (PI * 0)%R in H by lra. - apply Rmult_lt_reg_l in H; trivial. - assert (x0 = 1)%Z. - { - apply lt_IZR in H. - apply lt_IZR in H0. - lia. - } - rewrite H5 in H4. - rewrite Rmult_1_r in H4. - rewrite H4 in H3. - generalize cos_PI; intros. - lra. - } - lra. - Qed. - -Lemma cos_eq_1_alt (x : R) : - cos x = R1 -> - exists (k:Z), x = (2 * PI * IZR(k))%R. -Proof. - intros Hx. - assert (PI2_neq0: (2 * PI <> 0)%R). - { - generalize PI_neq0. - lra. - } - destruct (euclidian_division x (2*PI) PI2_neq0) as (q & r & EQ & Hr & Hr'). - exists q. - rewrite <- (Rplus_0_r (_*_)). subst. - rewrite Rmult_comm. - apply Rplus_eq_compat_l. - rewrite cos_plus in Hx. - assert (H : cos (IZR q * 2 * PI)%R = 1%R) by ( apply cos_eq_1_1; now exists q). - rewrite <- Rmult_assoc in Hx. - rewrite H, Rmult_1_l in Hx. - rewrite sin_eq_0_1 in Hx. - - rewrite Rmult_0_l, Rminus_0_r in Hx. - rewrite Rabs_right in Hr'. - + destruct Hr as [Hr | ->]; trivial. - exfalso. - generalize (cos_lt_1 r Hr Hr'); intros. - lra. - + generalize PI_RGT_0; lra. - - exists (2*q)%Z. - rewrite mult_IZR. - lra. - Qed. - -Lemma cos_eq_1_nneg (x : R) : - cos x = R1 -> - (0 <= x)%R -> - exists (k:nat), x = (2 * PI * INR(k))%R. -Proof. - intros. - generalize (cos_eq_1 x H); intros. - destruct H1. - rewrite H1 in H0. - replace (0%R) with (2 * PI * 0)%R in H0 by lra. - apply Rmult_le_reg_l in H0. - - apply le_IZR in H0. - exists (Z.abs_nat x0). - rewrite H1. - do 2 f_equal. - destruct x0; simpl; trivial; try lia. - now rewrite INR_IPR. - - generalize PI_RGT_0; lra. -Qed. - -Lemma sin_cos_eq x y: - sin x = sin y /\ cos x = cos y -> - exists (k:Z), - x = (y + 2 * PI * IZR(k))%R. -Proof. - intros. - generalize (cos_minus x y); intros. - destruct H. - rewrite H, H1 in H0. - generalize (sin2_cos2 y); intros. - rewrite Rplus_comm in H0. - unfold Rsqr in H2. - rewrite H2 in H0. - apply cos_eq_1 in H0. - destruct H0. - exists x0. - rewrite <- H0. - lra. -Qed. - -Lemma nth_root_eq j k n : - j mod (S n) = k mod (S n) <-> - nth_root j (S n) = nth_root k (S n). -Proof. - split; intros. - - now apply nth_root_mod. - - unfold nth_root in H. - replace (S n) with (n + 1) in H by lia. - inversion H; clear H. - generalize (sin_cos_eq (2 * PI * INR j / INR (n + 1)) - (2 * PI * INR k / INR (n + 1))); intros. - destruct H. - + split; trivial. - + replace (2 * PI * INR k / INR (n + 1) + 2 * PI * IZR x)%R with - (2 * PI * (INR k / INR (n+1) + IZR x))%R in H by lra. - replace (2 * PI * INR j / INR (n + 1))%R with - (2 * PI * (INR j / INR (n + 1)))%R in H by lra. - apply (f_equal (fun r => (/ (2 * PI)) * r))%R in H. - assert (2 * PI <> 0)%R. - { - generalize PI_neq0. - lra. - } - rewrite <- Rmult_assoc in H. - rewrite <- Rinv_l_sym, Rmult_1_l in H; trivial. - rewrite <- Rmult_assoc in H. - rewrite <- Rinv_l_sym, Rmult_1_l in H; trivial. - clear H0 H1 H2. - repeat rewrite plus_INR in H. - simpl in H. - assert (possn:(INR n + 1)%R <> 0%R). - { - generalize (pos_INR n); lra. - } - field_simplify in H; try lra. - apply (f_equal (Rmult (INR n + 1))) in H. - field_simplify in H; try lra. - repeat rewrite INR_IZR_INZ in H. - repeat rewrite <- mult_IZR in H. - repeat rewrite <- plus_IZR in H. - apply eq_IZR in H. - apply Nat2Z.inj. - repeat rewrite Nat2Z.inj_mod. - rewrite H. - transitivity ((Z.of_nat k + (x * (Z.of_nat (S n)))) mod Z.of_nat (S n))%Z. - * f_equal. - rewrite Nat2Z.inj_succ. - lia. - * now rewrite Zdiv.Z_mod_plus_full. -Qed. - -Lemma nth_root_not_1 j n : - j mod (S n) <> 0 -> - nth_root j (S n) <> R1. -Proof. - unfold nth_root. - intros. - unfold RtoC. - unfold not. - intros. - replace (S n) with (n + 1) in H0 by lia. - inversion H0; clear H0. - assert (xnneg :(0 <= 2 * PI * INR j / INR (n + 1))%R). - { - apply Rmult_le_pos. - - generalize (pos_INR j); intros. - apply Rmult_le_pos; trivial. - generalize PI_RGT_0; intros. - lra. - - left. - apply Rinv_0_lt_compat. - apply lt_0_INR. - lia. - } - apply cos_eq_1_nneg in H2; trivial. - destruct H2. - apply (f_equal (fun r => (r /(2 * PI))%R)) in H0. - unfold Rdiv in H0. - rewrite Rmult_comm in H0. - assert ((2 * PI)%R <> R0). - { - generalize PI_neq0; intros. - lra. - } - do 2 rewrite <- Rmult_assoc in H0. - rewrite <- Rinv_l_sym in H0; trivial. - rewrite Rmult_1_l in H0. - symmetry in H0. - rewrite Rmult_comm in H0. - rewrite <- Rmult_assoc in H0. - rewrite <- Rinv_l_sym in H0; trivial. - rewrite Rmult_1_l in H0. - replace (n+1) with (S n) in H0 by lia. - apply (f_equal (fun r => (r * INR (S n))%R)) in H0. - rewrite Rmult_assoc in H0. - rewrite <- Rinv_l_sym in H0. - - rewrite Rmult_1_r in H0. - rewrite <- mult_INR in H0. - apply INR_eq in H0. - apply (f_equal (fun k => k mod (S n))) in H0. - rewrite Nat.mod_mul in H0; try lia. - - apply not_0_INR. - lia. - Qed. - -Lemma nth_root_1 j n : - j mod (S n) = 0 -> - nth_root j (S n) = R1. -Proof. - intros. - rewrite (nth_root_mod j 0 n). - - now rewrite nth_root_0. - - rewrite H. - rewrite Nat.mod_small; lia. -Qed. - -Lemma Cinv_1_r : - Cinv 1%R = 1%R. -Proof. - unfold Cinv. - unfold RtoC. - simpl. - f_equal; field. -Qed. - -Lemma Cpow_sub_r (c : C) (n m : nat): - m <= n -> - c <> R0 -> - (c ^ (n - m))%C = (c ^ n / c ^ m)%C. -Proof. - intros. - assert (Cmult (Cpow c (n - m)) (Cpow c m) = Cpow c n). - { - rewrite <- Cpow_add_r. - f_equal. - lia. - } - rewrite <- H1. - unfold Cdiv. - rewrite <- Cmult_assoc. - rewrite Cinv_r. - - now rewrite Cmult_1_r. - - now apply Cpow_nz. - Qed. - -Lemma nth_root_diff j k n : - j <= k -> - Cdiv (nth_root k (S n)) (nth_root j (S n)) = nth_root (k-j) (S n). -Proof. - intros. - rewrite (prim_nth_root k _). - rewrite (prim_nth_root j _). - rewrite (prim_nth_root (k-j) _). - rewrite Cpow_sub_r; trivial. - apply nth_root_not_0. -Qed. - - -Instance list_Cplus_perm_proper : Proper (@Permutation _ ==> eq) list_Cplus. -Proof. - repeat red. - apply Permutation_ind_bis; trivial; simpl; intros. - - now rewrite H0. - - rewrite H0; ring. - - now rewrite H0, H2. -Qed. - -Lemma C_telescope_mult (c : C) (n : nat) : - (Cmult (c - R1) (list_Cplus (map (fun j => Cpow c j) (seq 0 (S n)))) = - (Cpow c (S n) - 1%R))%C. -Proof. - induction n. - - simpl; ring. - - rewrite seq_S. - simpl plus. - rewrite map_app. - unfold map at 2. - rewrite <- Permutation_cons_append. - unfold list_Cplus; fold list_Cplus. - transitivity ((c - R1) * c ^ S n + (c - R1) * list_Cplus (map (fun j : nat => c ^ j) (seq 0 (S n))))%C; [ring |]. - rewrite IHn. - simpl; ring. -Qed. - -Lemma C_telescope_div (c : C) (n : nat) : - c <> R1 -> - list_Cplus (map (fun j => Cpow c j) (seq 0 (S n))) = - Cdiv (Cpow c (S n) - 1%R) (c - R1). -Proof. - intros. - generalize (C_telescope_mult c n); intros. - rewrite <- H0. - unfold Cdiv. - rewrite Cmult_comm. - rewrite Cmult_assoc. - rewrite Cinv_l. - - now rewrite Cmult_1_l. - - unfold not. - intros. - unfold not in H. - apply H. - apply (f_equal (fun cc => Cplus cc (RtoC R1))) in H1. - now ring_simplify in H1. -Qed. - -Lemma C_telescope_pow_0 (c : C) (n : nat) : - c <> R1 -> - Cpow c (S n) = 1%R -> - list_Cplus (map (fun j => Cpow c j) (seq 0 (S n))) = 0%R. -Proof. - intros. - rewrite C_telescope_div; trivial. - rewrite H0. - field. - simpl. - unfold not; intros. - apply (f_equal (fun cc => (cc + 1%R)%C)) in H1. - now ring_simplify in H1. -Qed. - -Lemma sum_nth_roots_0 n : - list_Cplus (map (fun j => Cpow (nth_root 1 (S (S n))) j) (seq 0 (S (S n)))) = R0. -Proof. - apply C_telescope_pow_0. - - apply nth_root_not_1. - rewrite Nat.mod_1_l; lia. - - now rewrite nth_root_npow. - Qed. - -Lemma sum_nth_roots_0_gen k n : - k mod (S (S n)) <> 0 -> - list_Cplus (map (fun j => Cpow (nth_root k (S (S n))) j) (seq 0 (S (S n)))) = R0. -Proof. - intros. - rewrite C_telescope_div. - - rewrite nth_root_npow. - unfold Cminus. - rewrite Cplus_opp_r. - unfold Cdiv. - now rewrite Cmult_0_l. - - now apply nth_root_not_1. -Qed. - -Lemma pow_nth_root_prim n : - Cpow (nth_root 1 (S n)) (S n) = R1. -Proof. - unfold nth_root. - rewrite de_moive. - replace (INR (S n) * (2 * PI * INR 1 / INR (S n)))%R with (2 * PI)%R. - - now rewrite cos_2PI, sin_2PI. - - replace (INR 1) with R1 by now unfold INR. - field. - apply S_INR_not_0. - Qed. - -Lemma list_Cplus_app l1 l2 : - list_Cplus (l1 ++ l2) = Cplus (list_Cplus l1) (list_Cplus l2). -Proof. - induction l1. - - simpl. - now rewrite Cplus_0_l. - - simpl. - rewrite IHl1. - now rewrite Cplus_assoc. - Qed. - -Lemma root_prod_1 j n : - list_Cplus - (map (fun k => Cmult (Cpow (nth_root j (S n)) k) (Cpow (Cinv (nth_root k (S n))) j)) - (seq 0 (S n))) = INR (S n). -Proof. - replace (map (fun k => Cmult (Cpow (nth_root j (S n)) k) (Cpow (Cinv (nth_root k (S n))) j)) - (seq 0 (S n))) with - (map (fun k => RtoC R1) (seq 0 (S n))). - - induction n. - + simpl. - now rewrite Cplus_0_r. - + rewrite seq_S. - rewrite map_app. - rewrite list_Cplus_app. - rewrite IHn. - replace (S n) with (n + 1) by lia. - replace (S (n + 1)) with (n + 2) by lia. - simpl. - do 2 rewrite plus_INR. - simpl. - unfold RtoC. - rewrite Cplus_0_r. - unfold Cplus, fst, snd. - f_equal; lra. - - apply map_ext. - intros. - rewrite Cpow_inv. - + do 2 rewrite Cpow_nth_root. - replace (j * a) with (a * j) by lia. - rewrite Cinv_r. - * now replace R1 with 1%R by lra. - * apply nth_root_not_0. - + apply nth_root_not_0. - Qed. - -Lemma pow_nth_root j n : - Cpow (nth_root j (S n)) (S n) = R1. -Proof. - rewrite prim_nth_root. - rewrite <- Cpow_mult_r. - replace (j * S n) with (S n * j) by lia. - rewrite Cpow_mult_r. - rewrite pow_nth_root_prim. - now rewrite Cpow_1_l. -Qed. - -Lemma nth_root_mul j k n : - Cmult (nth_root j (S n)) (nth_root k (S n)) = nth_root (j + k) (S n). -Proof. - intros. - rewrite (prim_nth_root k _). - rewrite (prim_nth_root j _). - rewrite (prim_nth_root (j + k) _). - rewrite Cpow_add_r; trivial. - Qed. - -Lemma nth_root_Sn n : - nth_root (S n) (S n) = 1%R. -Proof. - rewrite prim_nth_root. - now rewrite nth_root_npow. -Qed. - -Lemma nth_root_inv j n : - Cinv (nth_root j (S n)) = nth_root (S n - (j mod S n)) (S n). -Proof. - generalize (nth_root_diff (j mod S n) (S n) n); intros. - rewrite <- H. - - rewrite nth_root_Sn. - unfold Cdiv. - rewrite Cmult_1_l. - f_equal. - apply (nth_root_mod j (j mod S n) n). - rewrite Nat.mod_mod; try lia. - - assert (S n <> 0) by lia. - generalize (Nat.mod_upper_bound j (S n) H0); intros. - lia. - Qed. - -Lemma nth_root_div j k n : - Cdiv (nth_root j (S n)) (nth_root k (S n)) = - nth_root (j + (S n - (k mod S n))) (S n). -Proof. - unfold Cdiv. - rewrite nth_root_inv. - apply nth_root_mul. -Qed. - -Lemma nth_root_Cmod j n : - Cmod (nth_root j (S n)) = 1%R. -Proof. - unfold Cmod, nth_root, fst, snd. - rewrite Rplus_comm. - rewrite <- sqrt_1. - f_equal. - do 2 rewrite <- Rsqr_pow2. - now rewrite sin2_cos2. -Qed. - -Lemma Cmod_Cconj (c : C) : - Cmult c (Cconj c) = Rsqr (Cmod c). -Proof. - destruct c. - unfold Cconj, Cmod, Cmult, fst, snd. - rewrite Rsqr_sqrt. - - unfold RtoC. - f_equal; lra. - - apply Rplus_le_le_0_compat; apply pow2_ge_0. -Qed. - -Lemma nth_root_conj j n : - Cconj (nth_root j (S n)) = Cinv (nth_root j (S n)). -Proof. - generalize (Cmod_Cconj (nth_root j (S n))); intros. - rewrite nth_root_Cmod in H. - rewrite Rsqr_1 in H. - apply (f_equal (fun c => Cmult (/ nth_root j (S n)) c)) in H. - rewrite Cmult_1_r in H. - rewrite Cmult_assoc in H. - rewrite Cinv_l in H. - - now rewrite Cmult_1_l in H. - - apply nth_root_not_0. -Qed. - -Definition Ceval_poly (p : list C) (c : C) := - let cpows := map (fun j => Cpow c j) (seq 0 (length p)) in - list_Cplus (map (fun '(a, b) => Cmult a b) (combine p cpows)). - -Definition Ceval_Rpoly (p : list R) (c : C) := - let cpows := map (fun j => Cpow c j) (seq 0 (length p)) in - list_Cplus (map (fun '(a, b) => Cmult (RtoC a) b) (combine p cpows)). - -Fixpoint C_horner_eval (p : list C) (c : C) : C := - match p with - | nil => R0 - | a :: p' => Cplus a (Cmult c (C_horner_eval p' c)) - end. - -Fixpoint C_horner_eval_Rpoly (p : list R) (c : C) : C := - match p with - | nil => R0 - | a :: p' => Cplus a (Cmult c (C_horner_eval_Rpoly p' c)) - end. - -Lemma Ceval_horner_Rpoly (p : list R) (c : C) : - Ceval_Rpoly p c = C_horner_eval_Rpoly p c. -Proof. - induction p. - - now simpl. - - unfold Ceval_Rpoly. - simpl. - rewrite Cmult_1_r. - f_equal. - rewrite <- IHp. - unfold Ceval_Rpoly. - now rewrite list_cplus_mult_seq_comm. -Qed. - -Lemma pow2_S (j:nat) : - exists (k : nat), 2^j = S k. -Proof. - exists (2^j-1). - induction j. - - now simpl. - - simpl. - rewrite IHj. - lia. -Qed. - -Lemma nth_root_half_pow_aux n : - Cpow (nth_root (S n) (2 * (S n))) 2 = 1%R. -Proof. - replace (2 * (S n)) with (S (2 * n + 1)) by lia. - rewrite Cpow_nth_root. - do 2 replace (2 * (S n)) with (S (2 * n + 1)) by lia. - now rewrite nth_root_Sn. -Qed. - -Lemma pow2_inv x y : (x ^ 2)%R = y -> Rabs x = sqrt y. -Proof. - intros eqq1. - apply (f_equal sqrt) in eqq1. - destruct (Rle_dec 0 x). - - intros. - rewrite sqrt_pow2 in eqq1 by trivial. - rewrite eqq1. - rewrite Rabs_right; trivial. - generalize (sqrt_pos y); lra. - - assert (eqq1':sqrt ((-x) ^2) = sqrt y). - { - now rewrite <- Rsqr_pow2, <- Rsqr_neg, Rsqr_pow2. - } - rewrite sqrt_pow2 in eqq1' by lra. - now rewrite Rabs_left by lra. -Qed. - -Lemma Rabs_pm_l x y : Rabs x = y -> x = y \/ (- x)%R = y. -Proof. - unfold Rabs. - destruct (Rcase_abs); [right|left]; lra. -Qed. - -Lemma Rabs_pm_r x y : Rabs x = y -> x = y \/ x = (- y)%R. -Proof. - unfold Rabs. - destruct (Rcase_abs); [right|left]; lra. -Qed. - -Lemma Cpow_2 (c : C) : - Cpow c 2 = 1%R -> c = 1%R \/ c = (-1)%R. -Proof. - unfold Cpow. - rewrite Cmult_1_r. - intros. - destruct c. - unfold Cmult, fst, snd, RtoC in H. - injection H; intros; clear H. - ring_simplify in H0. - apply (f_equal (fun z => (/2 * z)%R)) in H0. - do 2 rewrite <- Rmult_assoc in H0. - rewrite <- Rinv_l_sym in H0; try lra. - rewrite Rmult_1_l, Rmult_0_r in H0. - apply Rmult_integral in H0. - destruct H0; subst; ring_simplify in H1. - - assert (0 <= r0 ^ 2)%R by apply pow2_ge_0. - lra. - - apply pow2_inv in H1. - rewrite sqrt_1 in H1. - apply Rabs_pm_r in H1. - unfold RtoC. - destruct H1; [left|right]; f_equal; lra. -Qed. - -Lemma nth_root_half_pow n : - nth_root (S n) (2 * (S n)) = (-1)%R. -Proof. - generalize (nth_root_half_pow_aux n); intros. - apply Cpow_2 in H. - destruct H; trivial. - replace (2 * (S n)) with (S (2 * n + 1)) in H by lia. - replace 1%R with R1 in H by lra. - generalize (nth_root_not_1 (S n) (2*n+1)); intros. - assert (S n mod S (2 * n + 1) <> 0). - { - rewrite Nat.mod_small; lia. - } - tauto. -Qed. - -Lemma odd_roots_prim j n : - Cpow (nth_root (2 * j + 1) (2 ^ (S n))) (2^n) = (-1)%R. -Proof. - generalize (pow2_S (S n)); intros. - destruct H. - rewrite H. - rewrite Cpow_nth_root. - rewrite <- H. - assert ((2 ^ n * (2 * j + 1) mod (2 ^ S n)) = - (2 ^ n mod (2 ^ S n))). - { - replace (2 ^n * (2 * j + 1)) with (2 ^ n + j*(2 * 2^n)) by lia. - replace (2 ^ (S n)) with (2 * 2^n). - - rewrite Nat.mod_add; try lia. - assert (2^n <> 0). - { - apply Nat.pow_nonzero. - lia. - } - lia. - - simpl. - lia. - } - rewrite H in H0. - apply nth_root_mod in H0. - rewrite <- H in H0. - rewrite H0. - generalize (pow2_S n); intros. - destruct H1. - simpl. - replace (2 ^ n + (2 ^n + 0)) with (2 * 2^n) by lia. - rewrite H1. - now rewrite nth_root_half_pow. -Qed. - -Definition odd_nth_roots (n : nat) := - (map (fun j => nth_root (2*j+1) (2 ^ (S n))) (seq 0 (2^n))). - -Definition V_odd_nth_roots (n : nat) : Vector C (2^n) := - fun j => nth_root (2 * (proj1_sig j) + 1) (2 ^ (S n)). - -Definition V_even_nth_roots (n : nat) : Vector C (2^n) := - fun j => nth_root (2 * (proj1_sig j)) (2 ^ (S n)). - -Definition V_nth_roots_half (n : nat) : Vector C (2^n) := - fun j => nth_root (proj1_sig j) (2 ^ (S n)). - -Definition odd_nth_roots_half (n : nat) := - (map (fun j => nth_root (2*j+1) (2 ^ (S (S n)))) (seq 0 (2^n))). - -Definition decode (p : list R) (n : nat) := - map (C_horner_eval_Rpoly p) (odd_nth_roots (S n)). - -Definition decode_eval (p : list R) (n : nat) := - map (Ceval_Rpoly p) (odd_nth_roots (S n)). - -Definition decode_Ceval (p : list C) (n : nat) := - map (Ceval_poly p) (odd_nth_roots (S n)). - -Definition decode_half (p : list R) (n : nat) := - map (C_horner_eval_Rpoly p) (odd_nth_roots_half n). - -Definition Cinner_prod (l1 l2 : list C) := - list_Cplus (map (fun '(a,b) => Cmult a b) (combine l1 l2)). - -Definition encode (cl : list C) (n : nat) := - let conjroots := map Cconj (odd_nth_roots (S n)) in - map (fun c => Cdiv c (2^(S n))%R) - (map (fun k => Cinner_prod cl (map (fun x => Cpow x k) conjroots)) - (seq 0 (2 ^(S n)))). - -Definition encode_half (cl : list C) (n : nat) := - let conjroots := map Cconj (odd_nth_roots_half n) in - map (fun c => Rdiv c (2^(S n))%R) - (map (fun k => (2 * (Re (Cinner_prod (firstn (2^n) cl) (map (fun x => Cpow x k) conjroots))))%R) - (seq 0 (2 ^(S n)))). - -Lemma Im_div_real (c : C) (r : R) : - r <> 0%R -> - Im c = 0%R <-> Im (Cdiv c r) = 0%R. -Proof. - intros. - destruct c. - unfold Cdiv, Cinv, Cmult, Im, fst, snd. - simpl. - split; intros. - - rewrite H0. - now field. - - field_simplify in H0; try easy. - unfold pow in H0. - field_simplify in H0; try easy. - unfold Rdiv in H0. - apply (f_equal (fun z => (z * r)%R)) in H0. - rewrite Rmult_assoc, Rmult_0_l in H0. - rewrite <- Rinv_l_sym in H0; trivial. - lra. -Qed. - -Lemma Cconj_im_0 (c : C) : - Cconj c = c -> Im c = 0%R. -Proof. - destruct c. - unfold Cconj; simpl. - intros. - injection H; intros. - lra. -Qed. - -Lemma map_inv_rev_even {A} n (l:list A) f (finv: forall x, f (f x) = x): - length l = 2 * n -> - map f l = rev l <-> - map f (firstn n l) = rev (skipn n l). -Proof. - intros llen. - split; intros HH. - - rewrite firstn_skipn_rev. - rewrite llen, map_rev, <- skipn_map, map_rev. - replace (2 * n - n) with n by lia. - now rewrite HH, rev_involutive. - - rewrite <- (firstn_skipn n l). - rewrite map_app, rev_app_distr. - f_equal; trivial. - apply (f_equal (rev (A:=_))) in HH. - repeat rewrite rev_involutive in HH. - rewrite <- HH. - rewrite map_rev, map_map. - now erewrite map_ext; [rewrite map_id |]. -Qed. - -Lemma map_inv_rev_odd {A} n (l:list A) f (finv: forall x, f (f x) = x): - length l = 2 * n + 1 -> - map f l = rev l <-> - map f (firstn n l) = rev (skipn (S n) l) /\ - option_map f (nth_error l n) = (nth_error l n). -Proof. - intros llen. - split; intros HH. - - rewrite firstn_skipn_rev. - rewrite llen, map_rev, <- skipn_map, map_rev. - replace (2 * n + 1 - n) with (S n) by lia. - split. - + now rewrite HH, rev_involutive. - + apply (f_equal (map f )) in HH. - rewrite map_map in HH. - erewrite map_ext in HH; [rewrite map_id in HH|]; trivial. - apply (f_equal (fun x => nth_error x n)) in HH. - rewrite map_rev, rev_nth_error in HH by (rewrite map_length, llen; lia). - rewrite map_length in HH. - replace (length l - S n)%nat with n in HH by lia. - now rewrite nth_error_map in HH. - - destruct HH as [HH1 HH2]. - rewrite <- (firstn_skipn (S n) l) at 2. - rewrite <- (firstn_skipn n l) at 1. - rewrite map_app, rev_app_distr. - f_equal; trivial. - case_eq (nth_error l n). - + intros a ntha. - rewrite (map_skipn_S_error _ _ _ ntha). - apply (f_equal (@rev _)) in HH1. - rewrite rev_involutive in HH1. - rewrite <- HH1. - rewrite map_cons, map_rev, map_map. - erewrite map_ext; [rewrite map_id | auto]. - apply nth_error_split in ntha. - destruct ntha as [l1 [l2 [? lenl1]]]; subst. - repeat rewrite firstn_app. - repeat rewrite rev_app_distr. - replace ((length l1) - length l1) with 0 by lia. - replace (S (length l1) - length l1) with 1 by lia. - rewrite firstn_cons. - repeat rewrite firstn_O. - repeat rewrite firstn_all2 by lia. - simpl; f_equal. - rewrite nth_error_app2 in HH2 by trivial. - replace (length l1 - length l1) with 0 in HH2 by lia. - simpl in HH2. - congruence. - + intros HH. - apply nth_error_None in HH. - assert (n = 0) by lia. - subst. - destruct l; simpl in *; [lia |]. - destruct l; simpl in *; [| lia]. - congruence. -Qed. - -Lemma conj_rev_even n cl : - length cl = 2 * n -> - map Cconj cl = rev cl <-> - map Cconj (firstn n cl) = firstn n (rev cl). -Proof. - intros llen. - generalize (map_inv_rev_even n cl Cconj (Cconj_conj) llen). - rewrite firstn_rev. - now replace (length cl - n) with n by lia. -Qed. - -Lemma conj_rev_odd cl n : - length cl = 2 * n + 1 -> - map Cconj cl = rev cl <-> - (map Cconj (firstn n cl) = rev (skipn (S n) cl) /\ - Im (nth n cl Ci) = 0%R). -Proof. - intros llen. - rewrite (map_inv_rev_odd n cl Cconj (Cconj_conj) llen). - split; intros [? HH]; split; trivial. - - rewrite (nth_error_nth' cl Ci) in HH by lia. - simpl in HH. - invcs HH. - apply Cconj_im_0. - now rewrite Cconj_conj. - - case_eq (nth_error cl n); simpl; trivial; intros. - rewrite (nth_error_nth _ _ _ H0) in HH. - destruct c; unfold Cconj; simpl in *. - rewrite HH. - do 2 f_equal. - lra. -Qed. - -Lemma conj_rev_half (cl_half:list C) : - let cl := cl_half ++ rev (map Cconj cl_half) in - map Cconj cl = rev cl. -Proof. - intros. - pose (n := length cl_half). - assert (length cl = 2*n). - { - unfold cl. - rewrite app_length. - rewrite rev_length. - rewrite map_length. - lia. - } - generalize (conj_rev_even n cl H); intros. - apply H0. - unfold cl. - rewrite firstn_app. - replace (length cl_half) with n by easy. - replace (n - n) with 0 by lia. - simpl. - rewrite rev_app_distr. - rewrite rev_involutive. - rewrite firstn_app. - replace (n - length (map Cconj cl_half)) with 0. - - rewrite map_app. - simpl. - f_equal. - now rewrite firstn_map. - - rewrite map_length. - lia. - Qed. - -Lemma conj_rev_half_conv n (cl:list C) : - length cl = 2*n -> - map Cconj cl = rev cl -> - let cl_half := firstn n cl in - cl = cl_half ++ rev (map Cconj cl_half) . -Proof. - intros. - generalize (conj_rev_even n cl H); intros. - apply H1 in H0. - unfold cl_half. - rewrite H0. - rewrite firstn_rev. - rewrite rev_involutive. - replace (length cl - n) with n by lia. - now rewrite firstn_skipn. -Qed. - -Lemma list_Cplus_rev (cl : list C) : - list_Cplus (rev cl) = list_Cplus cl. -Proof. - apply list_Cplus_perm_proper. - apply Permutation_sym. - apply Permutation_rev. -Qed. - -Lemma list_Cplus_conj (cl : list C) : - list_Cplus (map Cconj cl) = Cconj (list_Cplus cl). -Proof. - induction cl. - - unfold Cconj, fst, snd; simpl. - unfold RtoC; f_equal; lra. - - simpl. - rewrite Cplus_conj. - now f_equal. -Qed. - -Lemma conj_rev_half_sum n (cl : list C) : - length cl = 2*n -> - map Cconj cl = rev cl -> - let sum_cl_half := list_Cplus (firstn n cl) in - list_Cplus cl = (sum_cl_half + Cconj sum_cl_half)%C. -Proof. - intros. - rewrite (conj_rev_half_conv n cl); trivial. - rewrite list_Cplus_app. - unfold sum_cl_half. - f_equal. - rewrite list_Cplus_rev. - now rewrite list_Cplus_conj. -Qed. - -Lemma Cplus_conj (c : C) : - Cplus c (Cconj c) = (2 * Re c)%R. -Proof. - destruct c. - unfold Cplus, Cconj, RtoC, Re, fst, snd. - f_equal; lra. -Qed. - -Lemma conj_rev_half_sum_alt n (cl : list C) : - length cl = 2*n -> - map Cconj cl = rev cl -> - let sum_cl_half := list_Cplus (firstn n cl) in - list_Cplus cl = (2*(Re sum_cl_half))%R. -Proof. - intros. - rewrite (conj_rev_half_sum n); trivial. - now rewrite Cplus_conj. -Qed. - -Lemma conj_rev_rev (cl : list C) : - map Cconj cl = rev cl -> - cl = map Cconj (rev cl). -Proof. - intros. - apply (f_equal (fun l => rev l)) in H. - rewrite rev_involutive in H. - rewrite <- H at 1. - now rewrite map_rev. -Qed. - -Lemma list_Cplus_conj_rev_0 (cl : list C): - length cl < 2 -> - map Cconj cl = rev cl -> - Im (list_Cplus cl) = 0%R. -Proof. - intros. - pose (n := length cl). - destruct cl. - - now simpl. - - destruct cl. - + simpl. - simpl in H0. - inversion H0. - rewrite H2. - apply Cconj_im_0 in H2. - now rewrite Rplus_0_r. - + simpl in H. - lia. -Qed. - -Lemma list_Cplus_conj_rev_recur (n : nat) : - (forall (cl : list C), - length cl = n -> - map Cconj cl = rev cl -> - Im (list_Cplus cl) = 0%R) -> - forall (cl : list C), - length cl = n + 2 -> - map Cconj cl = rev cl -> - Im (list_Cplus cl) = 0%R. -Proof. - intros. - destruct cl; trivial. - simpl in H0. - assert (lcl: length cl = n+1) by lia. - assert(exists (c2 : C) (cl2 : list C), - cl = cl2 ++ (c2 :: nil)). - { - assert (cl <> nil). - { - unfold not; intros. - rewrite H2 in H0. - simpl in H0. - lia. - } - destruct (exists_last H2) as [? [??]]. - exists x0. - now exists x. - } - destruct H2 as [c2 [cl2 ?]]. - rewrite H2. - specialize (H cl2). - simpl. - rewrite list_Cplus_app. - rewrite im_plus. - rewrite H2 in H1. - simpl in H1. - rewrite map_app in H1. - rewrite rev_app_distr in H1. - simpl in H1. - inversion H1. - rewrite Cconj_conj in H5. - apply app_inv_tail in H5. - rewrite H; trivial. - - simpl. - lra. - - rewrite H2 in H0. - rewrite app_length in H0. - simpl in H0. - lia. - Qed. - -Lemma pair_induction (P : nat -> Prop) : - P 0 -> P 1 -> - (forall n, P n -> P (S n) -> P (S (S n))) -> - forall n, P n. -Proof. - intros H0 H1 Hstep n. - enough (P n /\ P (S n)) by easy. - induction n; intuition. -Qed. - -Lemma list_Cplus_conj_rev (cl : list C): - map Cconj cl = rev cl -> - Im (list_Cplus cl) = 0%R. -Proof. - intros HH. - apply (list_cons_app_hyp - (fun cl => Im (list_Cplus cl) = 0%R) - (fun x y => Cconj x = y)). - - trivial. - - simpl; intros. - rewrite Cconj_im_0; trivial. - lra. - - intros. - simpl. - rewrite <- Permutation_cons_append; simpl. - unfold Im in *. - rewrite H1. - rewrite <- H. - simpl; lra. - - rewrite <- HH. - apply Forall2_map_Forall. - rewrite Forall_forall; trivial. -Qed. - -Lemma combine_app {T} (cl1 cl2 cl1' cl2' : list T) : - length cl1 = length cl2 -> - length cl1' = length cl2' -> - combine (cl1 ++ cl1') (cl2 ++ cl2') = combine cl1 cl2 ++ combine cl1' cl2'. -Proof. - revert cl2. - induction cl1; intros; simpl; trivial. - - simpl in H. - symmetry in H. - apply length_zero_iff_nil in H. - rewrite H; now simpl. - - destruct cl2; simpl; [now simpl in H|]. - rewrite IHcl1; trivial. - simpl in H. - lia. - Qed. - -Lemma combine_rev {T} (cl1 cl2 : list T) : - length cl1 = length cl2 -> - combine (rev cl1) (rev cl2) = rev (combine cl1 cl2). -Proof. - revert cl2. - induction cl1; intros ; simpl; trivial. - destruct cl2; simpl. - - now rewrite combine_nil. - - simpl in H. - injection H; intros. - rewrite <- IHcl1; trivial. - rewrite combine_app. - + now simpl. - + now do 2 rewrite rev_length. - + now simpl. - Qed. - -Lemma Cmult_combine_rev (cl1 cl2 : list C) : - length cl1 = length cl2 -> - map (fun '(a, b) => (a * b)%C) (combine (rev cl1) (rev cl2)) = - rev (map (fun '(a, b) => (a * b)%C) (combine cl1 cl2)). -Proof. - intros. - rewrite <- map_rev. - f_equal. - now apply combine_rev. -Qed. - -Lemma Cmult_combine_conv (cl1 cl2 : list C) : - map (fun '(a, b) => (a * b)%C) (combine (map Cconj cl1) (map Cconj cl2)) = - map Cconj (map (fun '(a, b) => (a * b)%C) (combine cl1 cl2)). -Proof. - rewrite combine_map. - do 2 rewrite map_map. - apply map_ext. - intros. - destruct a. - now rewrite Cmult_conj. -Qed. - -Lemma map_mult_conj_rev (cl1 cl2 : list C): - map Cconj cl1 = rev cl1 -> - map Cconj cl2 = rev cl2 -> - length cl1 = length cl2 -> - let cl := map (fun '(a,b) => Cmult a b) (combine cl1 cl2) in - map Cconj cl = rev cl. -Proof. - intros. - assert (combine (map Cconj cl1) (map Cconj cl2) = - combine (rev cl1) (rev cl2)). - { - now rewrite H, H0. - } - apply (f_equal (fun ll => map (fun '(a, b) => (a * b)%C) ll)) in H2. - now rewrite Cmult_combine_rev, Cmult_combine_conv in H2. -Qed. - -Lemma map_pow_conj_rev (cl : list C) (n : nat) : - map Cconj cl = rev cl -> - map Cconj (map (fun c => Cpow c n) cl) = - rev (map (fun c => Cpow c n) cl). -Proof. - intros. - apply (f_equal (fun ll => map (fun cc => Cpow cc n) ll)) in H. - rewrite map_map in H. - rewrite map_rev in H. - rewrite <- H. - rewrite map_map. - apply map_ext. - intros. - now rewrite Cpow_conj. -Qed. - -Lemma map_conj_conj_rev (cl : list C) : - map Cconj cl = rev cl -> - map Cconj (map Cconj cl) = - rev (map Cconj cl). -Proof. - intros. - apply (f_equal (fun ll => map Cconj ll)) in H. - rewrite map_map in H. - rewrite map_rev in H. - rewrite <- H. - rewrite map_map. - now apply map_ext. -Qed. - -Lemma odd_nth_roots_conj_rev n : - let cl := map Cconj (odd_nth_roots (S n)) in - map Cconj cl = rev cl. -Proof. - simpl. - apply map_conj_conj_rev. - unfold odd_nth_roots. - rewrite <- map_rev. - rewrite rev_seq. - do 2 rewrite map_map. - apply map_ext_in; intros. - rewrite plus_0_r. - apply in_seq in H. - destruct H as [_ alt]. - rewrite plus_0_l in alt. - destruct (pow2_S (S (S n))); intros. - rewrite H. - rewrite nth_root_conj. - rewrite nth_root_inv. - f_equal. - rewrite <- H. - replace (2 ^ (S (S n))) with (2 * 2 ^ (S n)) by (simpl; lia). - rewrite Nat.mod_small; lia. -Qed. - -Lemma Cinner_prod_conj_rev (cl1 cl2 : list C) : - length (cl1) = length cl2 -> - map Cconj cl1 = rev cl1 -> - map Cconj cl2 = rev cl2 -> - Im (Cinner_prod cl1 cl2) = 0%R. -Proof. - intros. - unfold Cinner_prod. - apply list_Cplus_conj_rev. - apply map_mult_conj_rev; trivial. -Qed. - -Lemma encode_real (cl : list C) (n : nat): - map Cconj cl = rev cl -> - length cl = length (odd_nth_roots (S n)) -> - forall (x : C), - In x (encode cl n) -> Im x = 0%R. -Proof. - intros. - unfold encode in H1. - apply in_map_iff in H1. - destruct H1 as [? [? ?]]. - apply in_map_iff in H2. - destruct H2 as [? [? ?]]. - assert (Im x0 = 0%R). - { - rewrite <- H2. - apply Cinner_prod_conj_rev; trivial. - - now do 2 rewrite map_length. - - apply map_pow_conj_rev. - apply odd_nth_roots_conj_rev. - } - rewrite <- H1. - apply Im_div_real; trivial. - apply pow_nonzero. - lra. -Qed. - -Lemma Re_Im (c : C) : - Im c = 0%R <-> c = RtoC (Re c). -Proof. - destruct c. - unfold RtoC, Im, Re, fst, snd. - split; intros. - - now rewrite H. - - now inversion H. -Qed. - -Lemma clist_real (cl : list C) : - (forall (x : C), - In x cl -> Im x = 0%R) -> - cl = map RtoC (map Re cl). -Proof. - intros. - rewrite <- List.map_id at 1. - rewrite map_map. - apply map_ext_in. - intros. - specialize (H a H0). - now apply Re_Im. -Qed. - -Lemma encode_real_alt (cl : list C) (n : nat): - map Cconj cl = rev cl -> - length cl = length (odd_nth_roots (S n)) -> - encode cl n = map RtoC (map Re (encode cl n)). -Proof. - intros. - apply clist_real. - now apply encode_real. -Qed. - -Lemma Re_Cmult_R_C (r : R) (c : C) : - Re (Cmult r c) = Rmult r (Re c). -Proof. - destruct c; simpl. - ring. -Qed. - -Lemma Re_Cmult_list_Cplus (r : R) (cl : list C) : - RtoC (Re (Cmult r (list_Cplus cl))) = - Cmult r (list_Cplus (map RtoC (map Re cl))). -Proof. - intros. - induction cl. - - simpl. - unfold Cmult, Re, RtoC, fst, snd. - f_equal; lra. - - simpl. - rewrite Rmult_0_l. - rewrite Rminus_0_r. - rewrite Cmult_plus_distr_l. - rewrite <- IHcl. - rewrite <- RtoC_mult. - rewrite <- RtoC_plus. - f_equal. - rewrite Rmult_plus_distr_l. - f_equal. - now rewrite Re_Cmult_R_C. -Qed. - -Lemma map_commute {A} (l : list A) (f g : A -> A) : - (forall a, f (g a) = g (f a)) -> - map f (map g l) = map g (map f l). -Proof. - intros. - do 2 rewrite map_map. - apply map_ext. - apply H. -Qed. - -Lemma double_Re_commute (c : C) : - (2 * (Re c))%R = Re (2%R * c)%C. -Proof. - destruct c. - unfold RtoC, Re, fst, snd. - simpl. - now ring. -Qed. - -Lemma div_Re_commute (c : C) (r : R) : - (r <> 0)%R -> - (Re c / r)%R = Re (c / r)%C. -Proof. - intros. - destruct c. - unfold RtoC, Re, fst, snd. - simpl. - now field. -Qed. - -Lemma encode_half_correct (cl : list C) (n : nat): - length cl = 2^S n -> - map Cconj cl = rev cl -> - encode_half cl n = map Re (encode cl n). -Proof. - intros. - unfold encode_half, encode. - rewrite map_map. - rewrite map_map. - rewrite map_map. - apply map_ext. - intros. - rewrite double_Re_commute. - rewrite <- div_Re_commute; [| apply pow_nonzero; lra]. - f_equal. - unfold Cinner_prod. - symmetry. - rewrite (conj_rev_half_sum_alt (2 ^ n)). - - rewrite double_Re_commute. - rewrite re_RtoC. - do 3 f_equal. - rewrite firstn_map. - f_equal. - rewrite combine_firstn. - f_equal. - rewrite firstn_map. - f_equal. - unfold odd_nth_roots, odd_nth_roots_half. - do 2 rewrite firstn_map. - do 2 f_equal. - apply firstn_seq. - simpl; lia. - - rewrite map_length. - rewrite combine_length. - do 2 rewrite map_length. - unfold odd_nth_roots. - rewrite map_length. - rewrite seq_length. - rewrite H. - apply Nat.min_id. - - apply map_mult_conj_rev; trivial. - + apply map_pow_conj_rev. - apply odd_nth_roots_conj_rev. - + rewrite map_length. - rewrite map_length. - unfold odd_nth_roots. - rewrite map_length. - now rewrite seq_length. -Qed. - -Lemma encode_half_correct_alt (cl : list C) (n : nat): - length cl = 2^S n -> - map Cconj cl = rev cl -> - map RtoC (encode_half cl n) = encode cl n. -Proof. - intros. - generalize (encode_half_correct cl n H H0); intros. - rewrite (encode_real_alt cl n H0). - - now f_equal. - - unfold odd_nth_roots. - now rewrite map_length, seq_length. -Qed. - -Definition peval_mat (l : list C) := - let n := length l in - map (fun c => map (fun k => Cpow c k) (seq 0 n)) l. - -Definition V_peval_mat {n} (roots : Vector C n) : Matrix C n n := - (fun n1 n2 => Cpow (roots n1) (proj1_sig n2)). - -Definition conj_mat (m : list (list C)) := - map (fun cl => map Cconj cl) m. - -Definition V_conj_mat {n1 n2} (m : Matrix C n1 n2) := - fun n1' n2' => Cconj (m n1' n2'). - -Definition V_inv_mat {n1 n2} (m : Matrix C n1 n2) := - fun n1' n2' => Cinv (m n1' n2'). - -Definition transpose_mat (m : list (list C)) := - let n := length m in - map (fun k => - map (fun cl => nth k cl (RtoC 0%R)) m) - (seq 0 n). - -Definition mat_vec_mult (m : list (list C)) (v : list C) := - map (fun ml => Cinner_prod ml v) m. - -Definition vector_sum {n} (v : Vector C n) := - vector_fold_right Cplus 0%R v. - -Definition V_inner_prod {n} (v1 v2 : Vector C n) := - vector_sum (fun n' => Cmult (v1 n') (v2 n')). - -Definition V_mat_vec_mult {n1 n2} (m : Matrix C n1 n2) (v : Vector C n2) := - fun n' => V_inner_prod (m n') v. - -Definition vec_mat_mult (v : list C) (m : list (list C)) := - map (fun cl => Cinner_prod v cl) (transpose_mat m). - -Definition V_vec_mat_mult {n1 n2} (v : Vector C n1) (m : Matrix C n1 n2) := - fun n' => V_inner_prod v ((transpose m) n'). - -Definition mat_mat_mult (m1 m2 : list (list C)) := - map (fun v => mat_vec_mult m1 v) (transpose_mat m2). - -Definition V_mat_mat_mult {n1 n2 n3} (m1: Matrix C n1 n2) (m2 : Matrix C n2 n3) := - (fun n1' n3' => V_inner_prod (m1 n1') ((transpose m2) n3')). - -(* -Lemma mmv_mult_assoc_row (m2: list (list C)) (r1 v : list C) : - length r1 = length m2 -> - (forall r2, In r2 m2 -> length r2 = length v) -> - Cinner_prod r1 (mat_vec_mult m2 v) = - Cinner_prod (vec_mat_mult r1 m2) v. -Proof. - intros. - unfold mat_vec_mult, vec_mat_mult. - unfold Cinner_prod. - f_equal. -Admitted. -*) - -Definition Vscale {n} (r : C) (v : Vector C n) := - fun n' => Cmult r (v n'). - -Definition Vscale_r {n} (r : C) (v : Vector C n) := - fun n' => Cmult (v n') r. - -Lemma vector_fold_right_Cplus_0 ( v : Vector C 0) : - vector_fold_right Cplus 0%R v = 0%R. -Proof. - unfold vector_fold_right, vector_fold_right_dep, vector_fold_right_bounded_dep. - now simpl. -Qed. - -Lemma vector_sum_scale {n} (c : C) (v : Vector C n) : - Cmult c (vector_sum v) = vector_sum (Vscale c v). -Proof. - unfold vector_sum, Vscale. - induction n. - - do 2 rewrite vector_fold_right_Cplus_0. - now rewrite Cmult_0_r. - - rewrite vector_fold_right_Sn. - rewrite Cmult_plus_distr_l. - rewrite IHn. - now rewrite vector_fold_right_Sn. -Qed. - -Lemma vector_sum_scale_r {n} (v : Vector C n) (c : C) : - Cmult (vector_sum v) c = vector_sum (Vscale_r c v). -Proof. - unfold vector_sum, Vscale. - induction n. - - unfold vector_fold_right, vector_fold_right_dep, vector_fold_right_bounded_dep. - simpl. - now rewrite Cmult_0_l. - - rewrite vector_fold_right_Sn. - rewrite Cmult_plus_distr_r. - rewrite IHn. - now rewrite vector_fold_right_Sn. -Qed. - -Definition vmap' {A B} {n} (f:A->B) (v:Vector A n) : Vector B n - := fun n' => f (v n'). - -Definition mmap' {A B} {m n} (f:A->B) (mat:Matrix A m n) : Matrix B m n - := fun n1' n2' => f (mat n1' n2'). - -Lemma vector_sum_const {n} (c : C) : - vector_sum (ConstVector n c) = Cmult (RtoC (INR n)) c. -Proof. - unfold vector_sum, ConstVector. - induction n. - - rewrite vector_fold_right_Cplus_0. - now rewrite Cmult_0_l. - - rewrite vector_fold_right_Sn. - unfold vlast, vdrop_last. - rewrite S_INR. - replace ((INR n + 1)%R * c)%C with - (c + (INR n) * c)%C. - + f_equal. - rewrite <- IHn. - f_equal. - apply vec_eq_eq; intros x. - now destruct x. - + unfold RtoC. - destruct c. - unfold Cplus, Cmult, fst, snd. - f_equal; lra. - Qed. - -Lemma vector_sum_sum_transpose {n1 n2} (m : Matrix C n1 n2) : - vector_sum (fun n' => vector_sum (m n')) = - vector_sum (fun n'' => vector_sum ((transpose m) n'')). -Proof. - unfold transpose. - unfold vector_sum. - induction n1. - - induction n2. - + now do 2 rewrite vector_fold_right_Cplus_0. - + rewrite vector_fold_right_Cplus_0. - generalize (@vector_sum_const (S n2) (RtoC 0%R)); intros. - rewrite Cmult_0_r in H. - rewrite <- H at 1. - unfold vector_sum. - apply vector_fold_right_ext. - intros ?. - destruct i. - rewrite vector_fold_right_Cplus_0. - now unfold ConstVector. - - induction n2. - + rewrite vector_fold_right_Cplus_0. - generalize (@vector_sum_const (S n1) (RtoC 0%R)); intros. - rewrite Cmult_0_r in H. - rewrite <- H at 2. - unfold vector_sum. - apply vector_fold_right_ext. - intros ?. - rewrite vector_fold_right_Cplus_0. - now unfold ConstVector. - + do 2 rewrite vector_fold_right_Sn. - unfold vlast. - - admit. -Admitted. - - -Lemma V_mmv_mult_assoc {n1 n2 n3} - (m1 : Matrix C n1 n2) - (m2: Matrix C n2 n3) - (v : Vector C n3) : - V_mat_vec_mult m1 (V_mat_vec_mult m2 v) = - V_mat_vec_mult (V_mat_mat_mult m1 m2) v. -Proof. - intros. - unfold V_mat_vec_mult, V_mat_mat_mult. - unfold V_inner_prod, transpose. - apply vec_eq_eq; intros ?. - Admitted. - - -(* -Lemma mmv_mult_assoc (m1 m2: list (list C)) (v : list C) : - (forall r1, In r1 m1 -> length r1 = length m2) -> - (forall r2, In r2 m2 -> length r2 = length v) -> - mat_vec_mult m1 (mat_vec_mult m2 v) = - mat_vec_mult (mat_mat_mult m1 m2) v. -Proof. - intros. - unfold mat_mat_mult. - unfold mat_vec_mult at 1. - replace (map (fun ml => Cinner_prod ml (mat_vec_mult m2 v)) m1) with - (map (fun r1 => - Cinner_prod (vec_mat_mult r1 m2) v) m1). - - unfold vec_mat_mult, mat_vec_mult. - unfold Cinner_prod. - rewrite map_map. - admit. - - apply map_ext_in. - intros. - rewrite mmv_mult_assoc_row; trivial. - now specialize (H a H1). - -Admitted. -*) - -Lemma map_Cmult_combine_comm l1 l2 : - map (fun '(a0, b) => (a0 * b)%C) (combine l1 l2) = - map (fun '(a0, b) => (a0 * b)%C) (combine l2 l1). -Proof. - rewrite combine_swap. - rewrite map_map. - apply map_ext. - intros. - destruct a. - unfold fst, snd. - now rewrite Cmult_comm. -Qed. - -Lemma peval_mat_decode_Ceval (p : list C) n : - length p = length (odd_nth_roots (S n)) -> - decode_Ceval p n = - mat_vec_mult (peval_mat (odd_nth_roots (S n))) p. -Proof. - intros. - unfold decode_Ceval, mat_vec_mult, peval_mat, Ceval_poly. - rewrite map_map. - apply map_ext. - intros. - unfold Cinner_prod. - apply list_Cplus_perm_proper, refl_refl. - rewrite map_Cmult_combine_comm. - now rewrite <- H. -Qed. - -Lemma map_Cinv_l (lc : list C) (c : C) : - c <> 0%R -> - map (fun x => (x / c * c)%C) lc = lc. -Proof. - intros. - replace lc with (map (fun (x : C) => x) lc) at 2. - - apply map_ext. - intros. - unfold Cdiv. - rewrite <- Cmult_assoc. - rewrite Cinv_l; trivial. - now rewrite Cmult_1_r. - - now rewrite map_id. -Qed. - -Lemma nth_map_seq {A} (f : nat -> A) (d : A) (a m : nat) : - In a (seq 0 m) -> - nth a (map f (seq 0 m)) d = f a. -Proof. - intros. - induction m. - - now simpl in H. - - rewrite seq_S. - rewrite map_app. - destruct (lt_dec a m). - + simpl. - rewrite app_nth1. - * rewrite IHm; trivial. - rewrite in_seq; lia. - * now rewrite map_length, seq_length. - + rewrite in_seq in H. - assert (a = m) by lia. - rewrite H0. - simpl. - rewrite app_nth2. - * rewrite map_length, seq_length. - now replace (m - m) with 0 by lia. - * rewrite map_length, seq_length; lia. -Qed. - -Lemma conj_trans_mat_encode (cl : list C) n : - length cl = length (odd_nth_roots (S n)) -> - map (fun c => Cmult c (2^(S n))%R) (encode cl n) = - mat_vec_mult (conj_mat (transpose_mat (peval_mat (odd_nth_roots (S n))))) - cl. - Proof. - intros. - unfold encode. - rewrite map_map. - assert (length (odd_nth_roots (S n)) = 2^(S n)). - { - unfold odd_nth_roots. - now rewrite map_length, seq_length. - } - rewrite map_Cinv_l. - - unfold mat_vec_mult, conj_mat. - rewrite map_map. - unfold transpose_mat. - rewrite map_map. - unfold peval_mat. - rewrite map_length. - rewrite H0. - apply map_ext_in. - intros. - unfold Cinner_prod. - apply list_Cplus_perm_proper, refl_refl. - rewrite map_Cmult_combine_comm. - do 2 f_equal. - do 3 rewrite map_map. - apply map_ext. - intros. - rewrite <- Cpow_conj. - f_equal. - now rewrite nth_map_seq. - - unfold RtoC. - unfold not; intros. - replace (S n) with (n + 1) in H1 by lia. - injection H1; intros. - generalize (pow_nonzero 2 (n+1)); intros. - cut_to H3; lra. - Qed. - -(* -Lemma encode_decode_eval (cl : list C) (n : nat): - map Cconj cl = rev cl -> - length cl = length (odd_nth_roots (S n)) -> - decode_eval (map Re (encode cl n)) n = cl. -Proof. - intros. - unfold encode, decode_eval. - rewrite map_map. - rewrite map_map. - unfold Ceval_Rpoly. - rewrite map_length. - rewrite seq_length. - etransitivity. - - apply map_ext; intros. - apply list_Cplus_perm_proper. - rewrite combine_map. - rewrite combine_self. - repeat rewrite map_map. - reflexivity. - - apply nth_error_eqs; intros. - destruct (lt_dec i (length cl)). - + rewrite nth_error_map. - unfold odd_nth_roots at 2. - rewrite nth_error_map. - rewrite seq_nth_error. - * unfold option_map. - admit. - * unfold odd_nth_roots in H0. - rewrite map_length, seq_length in H0. - congruence. - + assert (length cl <= i) by lia. - apply nth_error_None in H1. - rewrite H1. - apply nth_error_None. - rewrite map_length, <- H0; lia. -Admitted. - -Lemma encode_decode (cl : list C) (n : nat): - map Cconj cl = rev cl -> - length cl = length (odd_nth_roots (S n)) -> - decode (map Re (encode cl n)) n = cl. -Proof. - intros. - rewrite <- encode_decode_eval with (n := n); trivial. - unfold decode, decode_eval. - apply map_ext. - intros. - now rewrite Ceval_horner_Rpoly. -Qed. - *) - - -(* claim (nxn vandermonde on odd roots) x conjugate transpose = n * I. *) - -Lemma mult_conj_root j n : - Cmult (nth_root j (S n)) (Cconj (nth_root j (S n))) = 1%R. -Proof. - rewrite nth_root_conj. - rewrite Cinv_r; trivial. - apply nth_root_not_0. -Qed. - -Lemma nth_root_half n : - nth_root (2 ^n) (2 ^ (S n)) = (-1)%R. -Proof. - destruct (pow2_S (S n)). - generalize (odd_roots_prim 0 n); intros. - replace (2 * 0 +1) with 1 in H by lia. - rewrite H in H0. - rewrite Cpow_nth_root in H0. - rewrite <- H in H0. - now replace (2^n * (2 * 0 + 1)) with (2 ^ n) in H0 by lia. -Qed. - -Lemma nth_root_opp j n : - (nth_root j (2 ^ (S n)) + nth_root (j + 2^n) (2 ^ (S n)) = 0%R)%C. -Proof. - destruct (pow2_S (S n)). - rewrite H. - rewrite <- nth_root_mul. - rewrite <- H. - rewrite nth_root_half. - ring. -Qed. - -Hint Rewrite Nat2Z.inj_mul : of_nat_re. -Hint Rewrite Nat2Z.inj_add : of_nat_re. -Hint Rewrite Nat2Z.inj_sub : of_nat_re. -Hint Rewrite Nat2Z.inj_mod : of_nat_re. -Hint Rewrite Nat2Z.inj_pow : of_nat_re. -Hint Rewrite Nat2Z.inj_div : of_nat_re. -Hint Rewrite Nat2Z.inj_0 : of_nat_re. - -Lemma of_nat_1 : Z.of_nat 1 = 1%Z. -Proof. - lia. -Qed. - -Lemma of_nat_2 : Z.of_nat 2 = 2%Z. -Proof. - lia. -Qed. - -Hint Rewrite of_nat_1 : of_nat_re. -Hint Rewrite of_nat_2 : of_nat_re. - - -Lemma mod_odd_even x y : - y <> 0 -> - Nat.Odd x -> Nat.Even y -> Nat.Odd (x mod y). -Proof. - intros. - rewrite Nat.mod_eq; trivial. - generalize (Nat.Even_mul_l y (x / y) H1); intros HH2. - apply Nat.odd_spec. - rewrite Nat.odd_sub. - - apply Nat.odd_spec in H0. - rewrite H0. - apply Nat.even_spec in HH2. - now rewrite <- Nat.negb_even, HH2. - - now apply Nat.mul_div_le. -Qed. - -Lemma mod_even_even x y : - y <> 0 -> - Nat.Even x -> Nat.Even y -> Nat.Even (x mod y). -Proof. - intros. - rewrite Nat.mod_eq; trivial. - generalize (Nat.Even_mul_l y (x / y) H1); intros HH2. - apply Nat.even_spec. - rewrite Nat.even_sub. - - apply Nat.even_spec in H0. - rewrite H0. - apply Nat.even_spec in HH2. - now rewrite HH2. - - now apply Nat.mul_div_le. -Qed. - -Lemma odd_nth_root_div_pow_sum_0 j k n : - (2*j+1) mod (2^(S n)) <> (2*k+1) mod (2 ^ (S n)) -> - let w := Cdiv (nth_root (2*j+1) (2 ^ (S n))) (nth_root (2*k+1) (2 ^ (S n))) in - list_Cplus (map (fun j => Cpow w j) (seq 0 (2^n))) = 0%R. -Proof. - intros. - destruct (pow2_S n). - rewrite H0. - destruct (pow2_S (S n)). - assert (nz:2 ^ (S n) <> 0) by lia. - apply C_telescope_pow_0. - - unfold w. - rewrite H1. - rewrite nth_root_div. - apply nth_root_not_1. - rewrite <- H1. - intros HH. - apply (f_equal Z.of_nat) in HH. - - autorewrite with of_nat_re in HH. - + rewrite <- Zdiv.Zplus_mod_idemp_r in HH. - rewrite <- Zdiv.Zminus_mod_idemp_l in HH. - rewrite Zdiv.Z_mod_same_full in HH. - rewrite Zdiv.Zminus_mod_idemp_r in HH. - rewrite Zdiv.Zplus_mod_idemp_r in HH. - apply H. - apply Nat2Z.inj. - autorewrite with of_nat_re. - apply (f_equal (fun x => (x + (2 * Z.of_nat k + 1)) mod 2 ^ Z.of_nat (S n)))%Z in HH. - rewrite Zplus_0_l in HH. - rewrite <- HH. - rewrite Zdiv.Zplus_mod_idemp_l. - f_equal. - lia. - + apply Nat.lt_le_incl. - apply Nat.mod_upper_bound. - lia. - - unfold w. - rewrite H1. - rewrite nth_root_div. - rewrite Cpow_nth_root. - apply nth_root_1. - rewrite <- H0, <- H1. - assert (exists (k2 : nat), - (2 ^ S n - (2 * k + 1) mod 2^ S n = 2*k2 + 1)). - { - assert (odd1:Nat.Odd ((2 * k + 1) mod 2 ^ S n)). - { - apply mod_odd_even; trivial; red; eauto. - exists (2 ^ n). - simpl; lia. - } - apply Nat.odd_spec in odd1. - apply Nat.odd_spec. - rewrite Nat.odd_sub. - - rewrite odd1. - now rewrite Nat.odd_pow by congruence. - - apply Nat.lt_le_incl. - apply Nat.mod_upper_bound. - lia. - } - destruct H2. - rewrite H2. - replace (2 * j + 1 + (2 *x1 + 1)) with (2 * (j + x1 + 1)) by lia. - replace (2 ^ n * (2 * (j + x1 + 1))) with - ((j + x1 + 1) * (2 ^ S n)). - + apply Nat.mod_mul; lia. - + simpl; lia. -Qed. - - -Lemma sum_pow_div (c1 c2 : C) n : - c1 = c2 -> - c2 <> 0%R -> - list_Cplus (map (fun j => Cpow (Cdiv c1 c2) j) (seq 0 n)) = INR (n). -Proof. - intros. - replace (map (fun j => Cpow (Cdiv c1 c2) j) (seq 0 n)) with - (map (fun (j:nat) => RtoC 1%R) (seq 0 n)). - - induction n. - + easy. - + rewrite seq_S. - rewrite map_app. - rewrite list_Cplus_app. - rewrite IHn. - rewrite S_INR. - simpl. - rewrite Cplus_0_r. - now rewrite RtoC_plus. - - apply map_ext. - intros. - rewrite H. - unfold Cdiv. - rewrite Cinv_r; trivial. - now rewrite Cpow_1_l. -Qed. - -Lemma odd_nth_root_div_pow_sum_1 j k n : - j mod (2^(S n)) = k mod (2 ^ (S n)) -> - let w := Cdiv (nth_root j (2 ^ (S n))) (nth_root k (2 ^ (S n))) in - list_Cplus (map (fun j => Cpow w j) (seq 0 (2^n))) = INR (2^n). -Proof. - intros. - unfold w. - destruct (pow2_S (S n)). - rewrite H0. - apply sum_pow_div. - - apply nth_root_mod. - now rewrite <- H0. - - apply nth_root_not_0. -Qed. - -Lemma conj_transpose (m : list (list C)) : - conj_mat (transpose_mat m) = transpose_mat (conj_mat m). -Proof. - unfold conj_mat, transpose_mat. - rewrite map_map, map_length. - apply map_ext; intros. - do 2 rewrite map_map. - apply map_ext; intros. - replace (RtoC 0%R) with (Cconj 0%R) at 2. - - now rewrite map_nth. - - unfold Cconj, RtoC, fst, snd. - f_equal; lra. -Qed. - -Lemma V_conj_transpose {n} (m : Matrix C n n) : - V_conj_mat (transpose m) = transpose (V_conj_mat m). -Proof. - unfold V_conj_mat, transpose. - easy. -Qed. - -(* -Lemma transpose_involutive (m : list (list C)) : - transpose_mat (transpose_mat m) = m. -Proof. - unfold transpose_mat. - rewrite map_length. - rewrite seq_length. - apply nth_error_eqs; intros. - rewrite nth_error_map. - destruct (lt_dec i (length m)). - - rewrite seq_nth_error; trivial. - unfold option_map. - rewrite nth_error_nth' with (d := nil); trivial. - f_equal. - rewrite map_map. - - admit. - - unfold option_map. - assert (length m <= i) by lia. - apply nth_error_None in H. - rewrite H. - assert (i >= length (seq 0 (length m))). - { - rewrite seq_length; lia. - } - apply nth_error_None in H0. - now rewrite H0. - Admitted. -*) - -(* -Lemma decode_mat_encode_mat_on_diag (n : nat): - let pmat := (peval_mat (odd_nth_roots (S n))) in - let prod := mat_mat_mult pmat (conj_mat (transpose_mat pmat)) in - forall n, -(* n < length prod -> *) - nth n (nth n prod nil) 0%R = RtoC (2^S n)%R. -Proof. - intros. - unfold prod, mat_mat_mult. - rewrite conj_transpose. - rewrite transpose_involutive. - unfold mat_vec_mult. - unfold pmat, peval_mat, conj_mat. - do 2 rewrite map_map. - -Admitted. -*) - -Lemma V_transpose_involutive {T} {n1 n2} (m : Matrix T n1 n2) : - transpose (transpose m) = m. -Proof. - now unfold transpose. -Qed. - -Lemma nth_root_conj_alt j n : - Cconj (nth_root j (S n)) = nth_root (S n - j mod (S n)) (S n). -Proof. - rewrite nth_root_conj. - now rewrite nth_root_inv. -Qed. - -Lemma vector_sum_list_Cplus {n} (v : Vector C n) : - vector_sum v = list_Cplus (vector_to_list v). -Proof. - unfold vector_sum, vector_to_list. - induction n. - - unfold vector_fold_right. - unfold vector_fold_right_dep. - unfold vector_fold_right_bounded_dep. - now simpl. - - rewrite vector_fold_right_Sn. - rewrite vector_fold_right_Sn. - simpl. - f_equal. - now rewrite IHn. - Qed. - -(* -Lemma list_Cplus_vector_sum (l : list C) : - list_Cplus l = vector_sum (list_to_vector l). -Proof. - unfold list_to_vector. - unfold vector_sum. - induction l. - - unfold vector_fold_right, vector_fold_right_dep, vector_fold_right_bounded_dep. - now simpl. - - simpl. - rewrite IHl. - unfold list_fold_right_dep. - unfold list_fold_right1_dep. - Search vcons. -Admitted. -*) - -Lemma V_telescope_pow_0 (c : C) (n : nat) : - c <> R1 -> - Cpow c (S n) = 1%R -> - vector_sum (fun (j : { j | j < S n}) => Cpow c (proj1_sig j)) = 0%R. -Proof. - intros. - generalize (C_telescope_pow_0 c n H H0); intros. - rewrite <- H1. - rewrite vector_sum_list_Cplus. - apply list_Cplus_perm_proper. - unfold vector_to_list. - assert (vector_fold_right cons nil - (fun j : {j : nat | j < S n} => (c ^ proj1_sig j)%C) = - (rev (map (fun j : nat => (c ^ j)%C) (seq 0 (S n))))). - { - clear H0 H1. - induction n. - - unfold vector_fold_right. - unfold vector_fold_right_dep. - unfold vector_fold_right_bounded_dep. - now simpl. - - rewrite vector_fold_right_Sn. - rewrite seq_S. - rewrite map_app. - rewrite rev_app_distr. - rewrite <- IHn. - simpl. - unfold vlast, proj1_sig. - f_equal. - } - rewrite H2. - symmetry. - apply Permutation_rev. -Qed. - -Lemma nat_mod_mul a b c : - (a * c) <> 0 -> - b mod c = 0 -> - (a * b) mod (a * c) = 0. -Proof. - intros. - generalize (Nat.div_mod_eq b c); intros. - rewrite H0 in H1. - replace (c * (b / c) + 0) with (c * (b / c)) in H1 by lia. - rewrite H1. - replace (a * (c * (b / c))) with ((b/c) * (a * c)) by lia. - now rewrite Nat.mod_mul. -Qed. - -Lemma V_decode_mat_encode_mat_off_diag (n : nat): - let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in - let prod := V_mat_mat_mult pmat (V_conj_mat (transpose pmat)) in - forall i j, - proj1_sig i <> proj1_sig j -> - prod i j = 0%R. -Proof. - intros. - unfold prod. - rewrite V_conj_transpose. - unfold V_mat_mat_mult, V_conj_mat. - rewrite V_transpose_involutive. - unfold pmat. - unfold V_inner_prod. - unfold V_peval_mat. - unfold V_odd_nth_roots. - destruct i. - destruct j. - unfold proj1_sig in *. - destruct (pow2_S (S n)). - rewrite H0. - destruct (pow2_S (S (S n))). - rewrite H1. - generalize (V_telescope_pow_0 (Cmult (nth_root (2 * x + 1) (S x2)) - (Cconj (nth_root (2 * x0 + 1) (S x2)))) - x1 - ); intros. - rewrite <- H2. - - f_equal. - apply vec_eq_eq; intros x3. - simpl. - now rewrite Cpow_mult_l, Cpow_conj. - - rewrite nth_root_conj_alt. - rewrite nth_root_mul. - apply nth_root_not_1. - rewrite <- H1. - assert (2 * x + 1 < 2 ^ (S (S n))). - { - simpl. - simpl in l. - lia. - } - assert (2 * x0 + 1 < 2 ^ (S (S n))). - { - simpl. - simpl in l0. - lia. - } - intros HH. - apply (f_equal Z.of_nat) in HH. - autorewrite with of_nat_re in HH; [|rewrite Nat.mod_small; lia]. - rewrite <- Zdiv.Zplus_mod_idemp_r in HH. - rewrite <- Zdiv.Zminus_mod_idemp_l in HH. - rewrite Zdiv.Z_mod_same_full in HH. - rewrite Zdiv.Zminus_mod_idemp_r in HH. - rewrite Zdiv.Zplus_mod_idemp_r in HH. - apply H. - apply Nat2Z.inj. - autorewrite with of_nat_re. - replace (2 * Z.of_nat x + 1 + (0 - (2 * Z.of_nat x0 + 1)))%Z with - (2 * Z.of_nat x - 2 * Z.of_nat x0)%Z in HH by lia. - apply (f_equal (fun x => (x + (2 * Z.of_nat x0)) mod 2 ^ Z.of_nat (S (S n))))%Z in HH. - rewrite Zplus_0_l in HH. - rewrite Zdiv.Zplus_mod_idemp_l in HH. - replace (2 * Z.of_nat x - 2 * Z.of_nat x0 + 2 * Z.of_nat x0)%Z with - (2 * Z.of_nat x)%Z in HH by lia. - rewrite Z.mod_small in HH. - + rewrite Z.mod_small in HH; try lia. - split; try lia. - apply inj_lt in l0. - rewrite Nat2Z.inj_pow in l0. - replace (2 ^ Z.of_nat (S (S n)))%Z with - (2 * 2 ^ Z.of_nat (S n))%Z; try lia. - rewrite (Nat2Z.inj_succ (S n)). - rewrite Z.pow_succ_r; lia. - + split; try lia. - apply inj_lt in l. - rewrite Nat2Z.inj_pow in l. - replace (2 ^ Z.of_nat (S (S n)))%Z with - (2 * 2 ^ Z.of_nat (S n))%Z; try lia. - rewrite (Nat2Z.inj_succ (S n)). - rewrite Z.pow_succ_r; lia. - - rewrite nth_root_conj_alt. - rewrite nth_root_mul. - generalize nth_root_1; intros. - rewrite Cpow_nth_root. - apply nth_root_1. - rewrite <- H0, <- H1. - replace (2 ^ (S (S n))) with ((2 ^ S n)*2) by (simpl;lia). - rewrite nat_mod_mul; try lia. - rewrite Nat.add_mod; try lia. - assert ((2 * x + 1) mod 2 = 1). - { - rewrite Nat.add_mod; try lia. - replace (2 * x) with (x * 2) by lia. - rewrite Nat.mod_mul; try lia. - now simpl. - } - rewrite H4. - assert (exists (k : nat), - (2 ^ S n * 2 - (2 * x0 + 1) mod (2^ S n * 2) = 2 * k + 1)). - { - assert (odd1:Nat.Odd ((2 * x0 + 1) mod (2 ^ S n * 2))). - { - apply mod_odd_even; trivial; red; eauto; try lia. - exists (2 ^ S n). - simpl; lia. - } - apply Nat.odd_spec in odd1. - apply Nat.odd_spec. - rewrite Nat.odd_sub. - - rewrite odd1. - rewrite Nat.odd_mul. - now rewrite Nat.odd_pow by congruence. - - apply Nat.lt_le_incl. - apply Nat.mod_upper_bound. - lia. - } - destruct H5. - rewrite H5. - replace (2 * x3 + 1) with (1 + x3 * 2) by lia. - replace ((1 + x3 * 2) mod 2) with 1; [now simpl | ]. - rewrite Nat.add_mod; try lia. - rewrite Nat.mod_mul; try lia. - simpl; lia. - Qed. - -Lemma root_conj_power_inv i j n : - Cmult (Cpow (nth_root i (S n)) j) - (Cconj (Cpow (nth_root i (S n)) j)) = 1%R. -Proof. - Search nth_root. - rewrite Cpow_nth_root. - now rewrite mult_conj_root. -Qed. - -Lemma V_decode_mat_encode_mat_on_diag (n : nat): - let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in - let prod := V_mat_mat_mult pmat (V_conj_mat (transpose pmat)) in - forall n0, - prod n0 n0 = RtoC (2^S n)%R. -Proof. - intros. - unfold prod. - rewrite V_conj_transpose. - unfold V_mat_mat_mult, V_conj_mat. - rewrite V_transpose_involutive. - unfold pmat. - unfold V_inner_prod. - replace (fun n' : {n' : nat | n' < 2 ^ S n} => - (V_peval_mat (V_odd_nth_roots (S n)) n0 n' * - Cconj (V_peval_mat (V_odd_nth_roots (S n)) n0 n'))%C) with - (ConstVector (2 ^ S n) (RtoC 1%R)). - - rewrite vector_sum_const. - rewrite Cmult_1_r. - f_equal. - now rewrite pow_INR. - - apply vec_eq_eq. - intros ?. - unfold ConstVector. - unfold V_peval_mat. - unfold V_odd_nth_roots. - destruct (pow2_S (S (S n))). - rewrite H. - now rewrite root_conj_power_inv. - Qed. - -(* -Lemma decode_mat_encode_mat_off_diag (n : nat): - let pmat := (peval_mat (odd_nth_roots (S n))) in - let prod := mat_mat_mult pmat (conj_mat (transpose_mat pmat)) in - forall i j, -(* i < length prod -> - j < length prod -> -*) - i <> j -> - nth i (nth j prod nil) 0%R = 0%R. -Proof. - intros. - unfold prod, mat_mat_mult. - rewrite conj_transpose. - rewrite transpose_involutive. - unfold mat_vec_mult. - unfold pmat, peval_mat, conj_mat. - do 2 rewrite map_map. - - -Admitted. - - -Lemma decode_mat_encode_mat (cl : list C) (n : nat): - length cl = length (odd_nth_roots (S n)) -> - let pmat := (peval_mat (odd_nth_roots (S n))) in - mat_vec_mult (mat_mat_mult pmat (conj_mat (transpose_mat pmat))) cl = - map (fun c => Cmult c (2^(S n))%R) cl. -Proof. - intros. - unfold mat_vec_mult, mat_mat_mult. - rewrite conj_transpose. - rewrite transpose_involutive. - rewrite map_map. - unfold mat_vec_mult. - unfold pmat. - Admitted. -*) - -Lemma vector_sum_all_but_1_0 n (i : {i : nat | i < n}) c : - vector_sum (fun (n' : {n' : nat | n' < n}) => if eq_nat_decide (proj1_sig i) (proj1_sig n') then c else 0%R) = c. -Proof. - unfold vector_sum. - induction n. - - unfold vector_fold_right. - unfold vector_fold_right_dep. - unfold vector_fold_right_bounded_dep. - simpl. - destruct i. - lia. - - rewrite vector_fold_right_Sn. - destruct (Compare_dec.lt_dec (proj1_sig i) n). - + unfold vlast, proj1_sig. - destruct i. - unfold proj1_sig in l. - match_destr. - * apply eq_nat_eq in e. - lia. - * specialize (IHn (exist _ x l)). - rewrite Cplus_0_l. - unfold vdrop_last. - rewrite <- IHn at 1. - apply vector_fold_right_ext. - unfold vec_eq. - intros. - now destruct i. - + unfold vlast, proj1_sig. - destruct i. - unfold proj1_sig in n0. - assert (x = n) by lia. - rewrite H. - unfold vdrop_last. - match_destr. - * replace (vector_fold_right Cplus 0%R - (fun H0 : {n' : nat | n' < n} => - let (x0, _) := H0 in if eq_nat_decide n x0 then c else 0%R))%C - with (RtoC (0%R)). - -- now rewrite Cplus_0_r. - -- assert (RtoC (0%R) = vector_sum (ConstVector n (RtoC (0%R)))). - { - rewrite vector_sum_const. - now rewrite Cmult_0_r. - } - rewrite H0 at 1. - unfold vector_sum, ConstVector. - apply vector_fold_right_ext. - unfold vec_eq. - intros. - destruct i. - match_destr. - apply eq_nat_eq in e0; lia. - * now generalize (eq_nat_refl n); intros. - Qed. - -Lemma index_eq {n} (i j : {n' : nat | n' < n}) : - proj1_sig i = proj1_sig j -> - i = j. -Proof. - intros. - destruct i. - destruct j. - unfold proj1_sig in H. - subst. - apply index_pf_irrel. -Qed. - -Lemma V_decode_mat_encode_mat_assoc_l (n : nat) (cl : Vector C (2^(S n))) : - let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in - let prod := V_mat_mat_mult pmat (V_conj_mat (transpose pmat)) in - V_mat_vec_mult prod cl = Vscale (RtoC (2^S n)%R) cl. -Proof. - generalize (V_decode_mat_encode_mat_on_diag n); intros. - generalize (V_decode_mat_encode_mat_off_diag n); intros. - simpl in H; simpl in H0. - unfold prod, pmat. - apply vec_eq_eq; intros x. - unfold V_mat_vec_mult, Vscale. - unfold V_inner_prod. - generalize (vector_sum_all_but_1_0 (2 ^ S n) x ((RtoC (2 ^ S n)%R) * cl x)); intros. - rewrite <- H1. - f_equal. - apply vec_eq_eq; intros ?. - clear H1. - destruct (eq_nat_decide (proj1_sig x) (proj1_sig i)). - - apply eq_nat_eq in e. - specialize (H x). - generalize (index_eq x i e); intros. - rewrite <- H1. - apply (f_equal (fun z => (z * cl x)%C)) in H. - replace (2 ^ S n)%R with (2 * 2^n)%R by now simpl. - rewrite <- H. - f_equal. - - specialize (H0 x i). - cut_to H0. - + apply (f_equal (fun z => (z * cl i)%C)) in H0. - rewrite Cmult_0_l in H0. - rewrite <- H0. - f_equal. - + unfold not; intros. - now apply eq_eq_nat in H1. -Qed. - -Lemma V_decode_mat_encode_mat (n : nat) (cl : Vector C (2^(S n))) : - let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in - let encmat := (V_conj_mat (transpose pmat)) in - V_mat_vec_mult pmat (V_mat_vec_mult encmat cl) = Vscale (RtoC (2^S n)%R) cl. -Proof. - unfold Vector in cl. - generalize (V_decode_mat_encode_mat_assoc_l n cl); intros. - rewrite <- H. - now rewrite V_mmv_mult_assoc. -Qed. - -Definition diag_matrix {n} (v : Vector C n) : Matrix C n n := - (fun i j => if eq_nat_decide (proj1_sig j) (proj1_sig i) then (v i) else 0%R). - -(* shows evaluation can be done by modified FFT of half size*) -Lemma V_peval_mat_prod (n : nat) : - V_peval_mat (V_odd_nth_roots (S n)) = - V_mat_mat_mult (V_peval_mat (V_even_nth_roots (S n))) - (diag_matrix (V_nth_roots_half (S n))). -Proof. - apply vec_eq_eq; intros ?. - apply vec_eq_eq; intros ?. - unfold V_peval_mat, diag_matrix, V_mat_mat_mult. - unfold V_inner_prod, transpose. - generalize (vector_sum_all_but_1_0 (2 ^ S n) i0 (V_odd_nth_roots (S n) i ^ proj1_sig i0)%C); intros. - rewrite <- H. - f_equal. - apply vec_eq_eq; intros ?. - match_destr. - - unfold V_odd_nth_roots, V_even_nth_roots, V_nth_roots_half. - destruct (pow2_S (S (S n))). - rewrite H0. - do 2 rewrite Cpow_nth_root. - rewrite nth_root_mul. - f_equal. - replace (proj1_sig i1) with (proj1_sig i0); try lia. - now apply eq_nat_eq. - - now rewrite Cmult_0_r. -Qed. - -(* shows enconding can be done by modified IFFT of half size*) -Lemma V_encode_mat_prod (n : nat) : - let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in - let encmat := (V_conj_mat (transpose pmat)) in - encmat = - V_mat_mat_mult - (diag_matrix (vmap' Cconj (V_nth_roots_half (S n)))) - (V_peval_mat (vmap' Cconj (V_even_nth_roots (S n)))). -Proof. - apply vec_eq_eq; intros ?. - apply vec_eq_eq; intros ?. - unfold V_peval_mat, diag_matrix, V_mat_mat_mult, V_conj_mat. - unfold V_inner_prod, transpose. - generalize (vector_sum_all_but_1_0 (2 ^ S n) i (Cconj (V_odd_nth_roots (S n) i0 ^ proj1_sig i))); intros. - rewrite <- H. - f_equal. - apply vec_eq_eq; intros ?. - match_destr. - - assert (eq_nat (proj1_sig i1) (proj1_sig i)). - { - apply eq_nat_eq in e. - apply eq_eq_nat; lia. - } - match_destr; try congruence. - unfold vmap', V_odd_nth_roots, V_even_nth_roots, V_nth_roots_half. - destruct (pow2_S (S (S n))). - rewrite H1. - rewrite <- Cpow_conj. - rewrite <- Cmult_conj. - f_equal. - do 2 rewrite Cpow_nth_root. - rewrite nth_root_mul. - f_equal. - replace (proj1_sig i1) with (proj1_sig i); try lia. - now apply eq_nat_eq. - - assert (~ eq_nat (proj1_sig i1) (proj1_sig i)). - { - unfold not; intros. - apply eq_nat_eq in H0. - rewrite H0 in n0. - now generalize (eq_nat_refl (proj1_sig i)). - } - match_destr; try congruence. - now rewrite Cmult_0_l. -Qed. - -Lemma nth_root_even_half (i n : nat) : - nth_root (2 * i) (2 * (S n)) = nth_root i (S n). -Proof. - unfold nth_root. - do 2 rewrite mult_INR. - generalize (S_INR_not_0 n); intros. - generalize (S_INR_not_0 1); intros. - do 2 f_equal; field; split; lra. -Qed. - -Program Definition index_reflect {n} (n' : {n' : nat | n' < n}) : {n' : nat | n' < n} := - (exist _ (n - 1 - (proj1_sig n')) _). -Next Obligation. - lia. -Qed. - -Lemma index_reflect_involutive {n} (i : {i' : nat | i' < n}) : - index_reflect (index_reflect i) = i. -Proof. - unfold index_reflect. - apply index_eq. - destruct i. - unfold proj1_sig; lia. -Qed. - -Definition vector_rev {n} {T} (v : Vector T n) := - fun i => v (index_reflect i). - -Definition vector_rev_conj {n} (v : Vector C n) := - forall i, - v i = Cconj (v (index_reflect i)). - -Lemma vector_rev_conj_plus {n} (v1 v2 : Vector C n) : - vector_rev_conj v1 -> - vector_rev_conj v2 -> - vector_rev_conj (vmap' (fun '(a,b) => Cplus a b) (vector_zip v1 v2)). -Proof. - unfold vector_rev_conj; intros. - unfold vector_zip, vmap'. - rewrite H, H0. - now rewrite Complex.Cplus_conj. -Qed. - -Lemma vector_rev_conj_mult {n} (v1 v2 : Vector C n) : - vector_rev_conj v1 -> - vector_rev_conj v2 -> - vector_rev_conj (vmap' (fun '(a,b) => Cmult a b) (vector_zip v1 v2)). -Proof. - unfold vector_rev_conj; intros. - unfold vector_zip, vmap'. - rewrite H, H0. - now rewrite Cmult_conj. -Qed. - -Lemma vector_rev_conj_scale {n} (r : R) (v : Vector C n) : - vector_rev_conj v -> - vector_rev_conj (Vscale (RtoC r) v). -Proof. - unfold vector_rev_conj; intros. - unfold Vscale. - rewrite H. - rewrite Cmult_conj. - f_equal. - unfold Cconj, fst, snd, RtoC. - f_equal; lra. -Qed. - -Lemma vector_rev_conj_const_R n (r : R) : - vector_rev_conj (ConstVector n (RtoC r)). -Proof. - unfold vector_rev_conj, ConstVector, RtoC, Cconj, fst, snd; intros. - f_equal; lra. -Qed. - -Lemma vector_rev_conj_conj {n} (v : Vector C n) : - vector_rev_conj v -> - vector_rev_conj (vmap' Cconj v). -Proof. - unfold vector_rev_conj, vmap'; intros. - now rewrite H. -Qed. - -Lemma vector_rev_conj_Cpow {n} i (v : Vector C n) : - vector_rev_conj v -> - vector_rev_conj (vmap' (fun c => Cpow c i) v). -Proof. - unfold vector_rev_conj, vmap'; intros. - rewrite H. - now rewrite Cpow_conj. -Qed. - -Lemma map_Cconj_vector_to_list {n} (v : Vector C n) : - map Cconj (vector_to_list v) = vector_to_list (vmap' Cconj v). -Proof. - unfold vector_to_list. - induction n. - - unfold vector_fold_right, vector_fold_right_dep, vector_fold_right_bounded_dep. - now simpl. - - do 2 rewrite vector_fold_right_Sn. - rewrite map_cons. - rewrite IHn. - f_equal. - Qed. - -Lemma rev_vector_to_list {n} {T} (v : Vector T n) : - rev (vector_to_list v) = vector_to_list (vector_rev v). -Proof. - unfold vector_rev. - unfold vector_to_list. - induction n. - - unfold vector_fold_right, vector_fold_right_dep, vector_fold_right_bounded_dep. - now simpl. - - do 2 rewrite vector_fold_right_Sn. - -Admitted. - -Lemma vector_rev_conj_sum {n} (v : Vector C n) : - vector_rev_conj v -> - Im (vector_sum v) = 0%R. -Proof. - rewrite vector_sum_list_Cplus. - intros. - apply list_Cplus_conj_rev. - rewrite map_Cconj_vector_to_list. - rewrite rev_vector_to_list. - f_equal. - apply vec_eq_eq; intros ?. - assert (vector_rev_conj (vmap' Cconj v)). - { - now apply vector_rev_conj_conj. - } - rewrite H0. - unfold vmap'. - now rewrite Cconj_conj. -Qed. - -Lemma vector_rev_conj_inner {n} (v1 v2 : Vector C n) : - vector_rev_conj v1 -> - vector_rev_conj v2 -> - Im (V_inner_prod v1 v2) = 0%R. -Proof. - intros. - apply vector_rev_conj_sum. - now apply vector_rev_conj_mult. -Qed. - -Lemma vector_cplus_comm {n} (v1 v2 : Vector C n) : - (vmap' (fun '(a,b) => Cplus a b) (vector_zip v1 v2)) = - (vmap' (fun '(a,b) => Cplus a b) (vector_zip v2 v1)). -Proof. - unfold vmap', vector_zip. - apply vec_eq_eq; intros ?. - apply Cplus_comm. -Qed. - -Lemma vector_cmult_comm {n} (v1 v2 : Vector C n) : - (vmap' (fun '(a,b) => Cmult a b) (vector_zip v1 v2)) = - (vmap' (fun '(a,b) => Cmult a b) (vector_zip v2 v1)). -Proof. - unfold vmap', vector_zip. - apply vec_eq_eq; intros ?. - apply Cmult_comm. -Qed. - -Lemma vector_cplus_assoc {n} (v1 v2 v3 : Vector C n) : - vmap' (fun '(a,b) => Cplus a b) (vector_zip v1 (vmap' (fun '(a,b) => Cplus a b) (vector_zip v2 v3))) = - vmap' (fun '(a,b) => Cplus a b) (vector_zip (vmap' (fun '(a,b) => Cplus a b) (vector_zip v1 v2)) v3). -Proof. - unfold vmap', vector_zip. - apply vec_eq_eq; intros ?. - apply Cplus_assoc. -Qed. - -Lemma vector_cmult_assoc {n} (v1 v2 v3 : Vector C n) : - vmap' (fun '(a,b) => Cmult a b) (vector_zip v1 (vmap' (fun '(a,b) => Cmult a b) (vector_zip v2 v3))) = - vmap' (fun '(a,b) => Cmult a b) (vector_zip (vmap' (fun '(a,b) => Cmult a b) (vector_zip v1 v2)) v3). -Proof. - unfold vmap', vector_zip. - apply vec_eq_eq; intros ?. - apply Cmult_assoc. -Qed. - -Lemma vector_rev_conj_odd_nth_roots (n : nat) : - vector_rev_conj (V_odd_nth_roots (S n)). -Proof. - unfold vector_rev_conj, V_odd_nth_roots. - intros. - destruct i. - unfold index_reflect, proj1_sig. - destruct (pow2_S (S (S n))). - rewrite H. - rewrite nth_root_conj_alt. - f_equal. - rewrite <- H. - replace (2^S (S n)) with (2 * 2^S n) by (simpl; lia). - rewrite Nat.mod_small; lia. -Qed. - -Lemma V_mat_encode_real (n : nat) (cl : Vector C (2^(S n))) : - let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in - let encmat := (V_conj_mat (transpose pmat)) in - vector_rev_conj cl -> - forall i, - Im ((V_mat_vec_mult encmat cl) i) = 0%R. -Proof. - unfold V_mat_vec_mult, transpose, V_peval_mat, V_conj_mat. - intros. - apply vector_rev_conj_inner; trivial. - apply (vector_rev_conj_conj (vmap' (fun c => Cpow c (proj1_sig i)) (V_odd_nth_roots (S n)))). - apply vector_rev_conj_Cpow, vector_rev_conj_odd_nth_roots. -Qed. - -Lemma V_mat_encode_real_alt (n : nat) (cl : Vector C (2^(S n))) : - let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in - let encmat := (V_conj_mat (transpose pmat)) in - vector_rev_conj cl -> - V_mat_vec_mult encmat cl = vmap' RtoC (vmap' Re (V_mat_vec_mult encmat cl)). -Proof. - intros. - apply vec_eq_eq; intros ?. - apply Re_Im. - now apply V_mat_encode_real. -Qed. - - diff --git a/coq/FHE/zp_prim_root.v b/coq/FHE/zp_prim_root.v deleted file mode 100644 index a8c81da0..00000000 --- a/coq/FHE/zp_prim_root.v +++ /dev/null @@ -1,3070 +0,0 @@ -Require Import Lia. -From mathcomp Require Import all_ssreflect zmodp poly ssralg cyclic fingroup finalg ring seq bigop. -Require Import encode. - -Set Implicit Arguments. -Unset Strict Implicit. -Unset Printing Implicit Defensive. -Set Bullet Behavior "Strict Subproofs". - -(* next 3 lemmas are copied from mathcomp_extra rsa *) - -(* This should be part of the standard library *) - -Lemma prime_modn_expSn p n : prime p -> n.+1 ^ p = (n ^ p).+1 %[mod p]. -Proof. -case: p => // p pP. -rewrite -[(_ ^ _).+1]addn0 (expnDn 1) big_ord_recr big_ord_recl /=. -rewrite subnn binn exp1n !mul1n addnAC -modnDmr; congr ((_ + _) %% _). -apply/eqP/dvdn_sum => -[i ?] _; exact/dvdn_mulr/prime_dvd_bin. -Qed. - -(* This should be part of the standard library *) - -Lemma fermat_little a p : prime p -> a ^ p = a %[mod p]. -Proof. -move=> pP. -elim: a => [|a IH]; first by rewrite exp0n // prime_gt0. -by rewrite prime_modn_expSn // -addn1 -modnDml IH modnDml addn1. -Qed. - -(* This should be part of the standard library *) - -Lemma fermat_little_pred a p : prime p -> ~(p %| a) -> a ^ p.-1 = 1 %[mod p]. -Proof. -move=> Pp pNDa. -have a_gt0 : 0 < a by case: a pNDa. -have aCp : coprime a p by rewrite coprime_sym prime_coprime //; apply/negP. -have aE : (egcdn a p).1 * a = 1 %[mod p]. - by case: egcdnP => //= km kn -> _; rewrite (eqP aCp) modnMDl. -rewrite -[_^ _]muln1 -modnMmr -[in LHS]aE // modnMmr. -rewrite mulnC -mulnA -expnS prednK ?prime_gt0 //. -by rewrite -modnMmr fermat_little // modnMmr aE. -Qed. - -Import ssralg.GRing. - -Section cyclic. -Local Open Scope ring_scope. - -Variable p : nat. - -Lemma prime_pbig2 (q : nat) : - prime q -> - 0 < (q.-1)%N. -Proof. - move=> p_prime. - move: (p_prime) => /(prime_gt0). - destruct q; [| destruct q]. - - by rewrite ltnn. - - by inversion p_prime. - - by rewrite ltn0Sn. -Qed. - - Lemma Fp_exp_expn (x : 'F_p) (n : nat): - prime p -> - nat_of_ord (x ^+ n)%R = x ^ n %% p. - Proof. - move=> p_prime. - have peqq := (Fp_cast p_prime). - induction n. - - by rewrite /= {1}peqq expn0. - - rewrite expnS exprS /mul /= IHn -modnMm. - rewrite {2 4 5}peqq. - by rewrite -(modnMm x _) modn_mod. - Qed. - - Lemma zp_prim_root_max : - prime p -> - { w : 'F_p | (p.-1).-primitive_root w}. - Proof. - move=> p_prime. - have pbig := (prime_pbig2 p_prime). - have/(nth_find 0)-HH: has (p.-1).-primitive_root [seq x <- enum 'F_p | x != Zp0]. - { - apply (@has_prim_root _ _ [seq x <- enum 'F_p | x != Zp0]) => //=. - - rewrite all_filter. - apply/allP => /= x xin. - apply/implyP=> xn0. - rewrite unity_rootE. - have/(fermat_little_pred p_prime)-eqq: ~ p %| x. - { - have xltp: x < p. - { - move: (ltn_ord x). - by rewrite {2}(Fp_cast p_prime). - } - have xpos: (0 < x) by by rewrite lt0n. - move/(dvdn_leq xpos); lia. - } - apply /eqP. - apply val_inj => /=. - by rewrite {2}(Fp_cast p_prime) -eqq Fp_exp_expn. - - apply filter_uniq. - apply enum_uniq. - - rewrite -rem_filter; [| apply enum_uniq]. - rewrite size_rem. - + by rewrite -cardE (card_Fp p_prime). - + by rewrite enumT Finite.EnumDef.enumDef /= mem_ord_enum. - } - by exists ([seq x <- enum 'F_p | x != 0]`_(find (p.-1).-primitive_root [seq x <- enum 'F_p | x != 0])). - Qed. - - Lemma zp_prim_root (n : nat) : - n > 0 -> - prime p -> - n %| p.-1 -> - { w : 'F_p | n.-primitive_root w}. - Proof. - intros npos prim div. - destruct (zp_prim_root_max prim). - generalize (dvdn_prim_root i div); intros. - by exists (exp x (p.-1 %/ n)). - Qed. - -Lemma inZp_add j k n : - inZp (j + k) = inZp j + inZp k :> 'Z_n. -Proof. - apply: val_inj => /=. - by rewrite modnDm. -Qed. - -Lemma inZp_mul j k n : - inZp (j * k) = inZp j * inZp k :> 'Z_n. -Proof. - apply: val_inj => /=. - by rewrite modnMm. -Qed. - -Lemma inZp_exp j k n : - inZp (j ^ k) = inZp j ^+ k :> 'Z_n. -Proof. - induction k. - - rewrite expn0 expr0. - by apply: val_inj => /=. - - rewrite exprS expnS -IHk. - apply inZp_mul. -Qed. - -End cyclic. - -Require Import ssrbool. - - -Section chinese. - - (* pairs represent (residue, modulus) *) - Fixpoint chinese_list (l : seq (nat * nat)) : nat := - match l with - | nil => 0 - | a :: nil => a.1 - | a :: l' => - chinese (a.2) (\prod_(i <- (map snd l')) i) - (a.1) (chinese_list l') - end. - - Lemma all_coprime_prod (a : nat) (l : seq nat) : - all (coprime a) l -> - coprime a (\prod_(i <- l) i). - Proof. - intros. - rewrite big_seq. - apply big_rec. - - apply coprimen1. - - intros. - move/allP/(_ _ H0): H. - by rewrite coprimeMr H1 => ->. - Qed. - - Lemma chinese_remainder_list_cons_l (a : nat * nat) (l : list (nat * nat)) : - all (coprime a.2) (map snd l) -> - chinese_list (a::l) == a.1 %[mod a.2]. - Proof. - induction l=> //= HH. - rewrite big_cons. - destruct l; trivial. - - rewrite chinese_modl //. - rewrite big_nil muln1. - by move/andP: HH => [-> _]. - - rewrite chinese_modl //. - move/andP: HH => [HH1 HH2/=]. - by rewrite coprimeMr HH1 /= all_coprime_prod. - Qed. - - Lemma pairwise_coprime_cons a l : - pairwise coprime (a :: l) -> - all (coprime a) l. - Proof. - simpl. - move /andP. - tauto. - Qed. - - Lemma chinese_reminder_list_cons_r (a : nat * nat) (l : list (nat * nat)) : - pairwise coprime (map snd (a::l)) -> - let m := \prod_(i <- map snd l) i in - chinese_list (a::l) == chinese_list l %[mod m]. - Proof. - intros. - simpl. - destruct l; trivial. - - rewrite big_nil /= modn1 //. - - rewrite chinese_modr //. - apply all_coprime_prod. - by apply pairwise_coprime_cons. - Qed. - - Lemma symmetricE {A} (f:A->A->bool) : (ssrbool.symmetric f) <-> (RelationClasses.Symmetric f). - Proof. - rewrite /symmetric /RelationClasses.Symmetric. - split; intros. - - by rewrite H. - - case_eq (f x y)=> eqq. - + by rewrite H. - + case_eq (f y x)=> eqq2//. - by rewrite (H _ _ eqq2) in eqq. - Qed. - - Lemma allE {A} (P:pred A) (l:list A) : all P l <-> List.Forall (fun x : A => P x) l. - Proof. - elim: l => /=. - - split=>// _. - - move=> a l IHl. - split. - + move/andP=> [ap pairP]. - constructor; tauto. - + inversion 1; subst. - by rewrite H2 IHl. - Qed. - - Lemma pairwiseE {A} (P:rel A) (l:list A) : pairwise P l <-> List.ForallOrdPairs P l. - Proof. - elim: l => /=. - - split=>// _. - constructor. - - move=> a l IHl. - split. - + move/andP=> [ap pairP]. - constructor;[| tauto]. - by apply allE. - + inversion 1; subst. - apply/andP; split. - * by apply allE. - * by apply IHl. - Qed. - - Lemma chinese_remainder_list_split : - forall (l1 l2 l : list (nat * nat)) (p : nat * nat), - pairwise coprime (map snd l) -> - l = l1 ++ (p::nil) ++ l2 -> - chinese_list l == p.1 %[mod p.2]. - Proof. - induction l1; intros. - - simpl in H0. - rewrite H0. - rewrite H0 in H. - rewrite chinese_remainder_list_cons_l //. - simpl in H. - by move/andP: H => [-> _]. - - pose (l3 := l1 ++ [:: p] ++ l2). - have ->/=: l = a :: l3 by by rewrite H0. - case_eq l3. - + subst l3; destruct l1; simpl; congruence. - + move=> _ _ <-. - have pc': pairwise coprime [seq i.2 | i <- l3]. - { - rewrite H0/= in H. - by move/andP: H => [_ ->]. - } - specialize (IHl1 l2 l3 p pc' (Logic.eq_refl _)). - rewrite <- (eqP IHl1). - have cp: coprime (a.2) (\prod_(i <- (map snd l3)) i). - - { - rewrite all_coprime_prod //. - subst l3. - simpl in H. - rewrite H0/= in H. - by move/andP: H => [-> _]. - } - move/eqP: (chinese_modr cp (a.1) (chinese_list l3)). - have ->: \prod_(i <- [seq i.2 | i <- l3]) i = \prod_(i <- [seq i.2 | i <- p::(l1++l2)]) i. - { - apply perm_big_AC. - - apply mulnA. - - apply mulnC. - - subst l3. - apply perm_map. - by rewrite perm_catCA perm_refl. - } - simpl. - rewrite big_cons. - by apply modn_muln. - Qed. - - Lemma chinese_remainder_list (l : list (nat * nat)) : - pairwise coprime (map snd l) -> - forall p, - p \in l -> - chinese_list l == p.1 %[mod p.2]. - Proof. - intros. - case/splitPr: H0 H=>l1 l2 HH. - by eapply chinese_remainder_list_split => //=. - Qed. - - Lemma chinese_remainder_list_unique (a b : nat) (l : list nat) : - pairwise coprime l -> - (forall p, - p \in l -> a == b %[mod p]) -> - a == b %[mod \prod_(i <- l) i]. - Proof. - induction l; simpl; intros. - - by rewrite big_nil !modn1. - - move /andP in H. - destruct H. - rewrite big_cons chinese_remainder. - + apply /andP. - split. - * apply H0. - rewrite in_cons. - apply /orP. - by left. - * apply IHl; trivial. - intros. - apply H0. - rewrite in_cons. - apply /orP. - by right. - + by apply all_coprime_prod. - Qed. - - Definition balanced_chinese_list (l : seq (nat * nat)) := - \sum_(p <- l) - (\prod_(q <- l | q != p) q.2) * - ((p.1 * (egcdn (\prod_(q <- l | q!= p) q.2) p.2).1) %% p.2). - - Lemma egcd_coprime (a b : nat) : - 0 < a -> - coprime a b -> - (egcdn a b).1 * a = 1 %[mod b]. - Proof. - intros. - case: egcdnP => //= km kn ->. - by rewrite (eqP H0) -modnDm modnMl add0n modn_mod. - Qed. - - Lemma egcd_coprime_mult (a b c : nat) : - 0 < a -> - coprime a b -> - c * (egcdn a b).1 * a = c %[mod b]. - Proof. - intros. - rewrite -mulnA -modnMmr egcd_coprime //. - rewrite modnS mod0n dvdn1. - case: eqVneq => [-> |]. - - by rewrite muln0 !modn1. - - by rewrite muln1. - Qed. - - Lemma sym_pairwiseP {T : Type} (r : T -> T -> bool) (sym:symmetric r) (x0 : T) (xs : seq T) : - reflect {in gtn (size xs) &, {homo nth x0 xs : i j / i <> j >-> r i j}} (pairwise r xs). - Proof. - induction xs; simpl; first by apply (iffP idP). - apply: (iffP andP). - - intros [??]?????. - destruct x; destruct y; simpl. - + lia. - + move/(all_nthP x0): H. - apply. - rewrite/in_mem/mem/= in H2. - lia. - + rewrite sym. - move/(all_nthP x0): H. - apply. - rewrite/in_mem/mem/= in H1. - lia. - + rewrite H0 in IHxs. - inversion IHxs. - apply H4; trivial. - lia. - - intros ?. - split. - + apply/(all_nthP x0)=> i ilt. - move: (H 0 i.+1) => /=. - apply; rewrite /in_mem/mem/=; lia. - + inversion IHxs => //. - elim H1 => x y xlt ylt xny. - move: (H x.+1 y.+1). - unfold in_mem, mem in *; simpl in *. - apply; lia. - Qed. - - Lemma allrel_sym {A:eqType} f (l1 l2: seq A) : - symmetric f -> - allrel f l1 l2 = allrel f l2 l1. - Proof. - rewrite allrelC=>sym. - apply eq_allrel => x y. - by rewrite sym. - Qed. - - Lemma pairwise_perm_sym {A:eqType} f (l1 l2: seq A) : - symmetric f -> - perm_eq l1 l2 -> - pairwise f l1 = pairwise f l2. - Proof. - move=> symf. - move: l1 l2. - wlog pimp: / forall l1 l2, perm_eq l1 l2 -> pairwise f l1 -> pairwise f l2. - - apply. - apply catCA_perm_ind=> l1 l2 l3. - rewrite !pairwise_cat !allrel_catr. - move/andP=>[/andP-[rel12 rel13] /andP-[p1 /andP-[rel23 /andP-[p2 p3]]]]. - repeat (try (apply/andP; split)) => //. - by rewrite allrel_sym. - - move=> l1 l2 pm. - apply Bool.eq_bool_prop_intro. - split =>/Bool.Is_true_eq_true=> HH - ; apply Bool.Is_true_eq_left - ; eapply pimp; try apply HH. - + by []. - + by rewrite perm_sym. - Qed. - - Lemma pairwise_coprime_perm l l2: - perm_eq l l2 -> - pairwise coprime l = pairwise coprime l2. - Proof. - intros. - apply pairwise_perm_sym => // x y. - apply coprime_sym. - Qed. - - Lemma prodn_filter1 [I : Type] (r : seq I) (F : I -> nat) : - \prod_(i <- r | F i != 1) F i = \prod_(i <- r) F i. - Proof. - induction r. - - by rewrite !big_nil. - - rewrite !big_cons -IHr. - case: eqVneq => /= [->|//]. - lia. - Qed. - - Lemma balanced_chinese_list_mod_inner (l : seq (nat * nat)) : - (forall p, p \in l -> 0 < p.2) -> - pairwise coprime (map snd l) -> - forall p, - p \in l -> - \prod_(q <- l | q != p) q.2 * ((p.1 * (egcdn (\prod_(q <- l | q != p) q.2) p.2).1) %% p.2) == p.1 %[mod p.2]. - Proof. - intros. - apply /eqP. - rewrite modnMmr mulnC. - apply egcd_coprime_mult. - - rewrite big_seq_cond. - apply prodn_cond_gt0 => i. - move/andP => [iinl _]. - by apply H. - - rewrite (perm_big_AC mulnA mulnC _ (r2:=p :: rem p l)). - + case: (eqVneq p.2 1) => [->|neq]. - { - by rewrite coprimen1. - } - rewrite -big_filter /= eqxx /=. - rewrite (_ : (\prod_(i <- [seq i <- rem p l | i != p]) i.2) = (\prod_(i <- (map snd [seq x <- rem p l | x.2 != 1]) | (i != p.2)) i)). - * rewrite coprime_sym -big_filter. - apply all_coprime_prod. - apply pairwise_coprime_cons. - rewrite (pairwise_perm_sym coprime_sym (perm_map snd (perm_to_rem H1))) in H0. - revert H0. - apply subseq_pairwise. - rewrite /= eqxx. - rewrite filter_map. - apply map_subseq. - eapply subseq_trans. - -- apply filter_subseq. - -- apply filter_subseq. - * rewrite big_map. - transitivity (\prod_(i <- [seq i <- rem p l | i != p] | i.2 != 1) i.2) - ; [by rewrite prodn_filter1 |]. - rewrite -big_filter. - symmetry; rewrite -big_filter. - rewrite -!filter_predI /predI. - f_equal. - apply eq_in_filter => x xin /=. - case: (eqVneq x.2 1) => /=; [by rewrite Bool.andb_false_r |intros ne1]. - case: (eqVneq x p). - { move => ->. - by rewrite eqxx. - } - case: (eqVneq x.2 p.2) => //= eqq2. - rewrite /eq_op/= eqq2 eqxx Bool.andb_true_r => eqq1. - move: H0. - rewrite (pairwise_perm_sym coprime_sym (perm_map snd (perm_to_rem H1))) /=. - move/andP=> []. - have x2in: x.2 \in [seq i.2 | i <- rem p l] by apply map_f. - move/allP/(_ x.2 x2in). - rewrite eqq2 /coprime gcdnn -eqq2 => eqq3. - by rewrite eqq3/= in ne1. - + by apply perm_to_rem. - Qed. - - Lemma pairwise_coprime_uniq (l : seq nat) : - (forall p, p \in l -> 1 < p) -> - pairwise coprime l -> - uniq l. - Proof. - intros. - rewrite uniq_pairwise. - have: (pairwise (fun x y => coprime x y && (1 x y xlt ylt xlty. - apply H. - by apply mem_nth. - } - apply sub_pairwise. - intros ???. - simpl. - assert (x = y -> false). - { - intros. - rewrite H2 in H1. - move /andP in H1; destruct H1. - rewrite /coprime gcdnn in H1. - move /eqP in H1. - lia. - } - by apply (contra_not_neq H2). - Qed. - - Lemma pairwise_coprime_uniq_pair (l : seq (nat * nat)) : - (forall p, p \in l -> 1 < p.2) -> - pairwise coprime (map snd l) -> - uniq l. - Proof. - intros. - apply (map_uniq (f := snd)). - apply pairwise_coprime_uniq; trivial. - intros. - move /mapP in H1. - destruct H1. - rewrite H2. - by apply H. - Qed. - - Lemma prod_split1 (l : seq (nat * nat)) (p : nat*nat) : - uniq l -> - p \in l -> - \prod_(q<-l) q.2 = p.2 * \prod_(q <- l | q != p) q.2. - Proof. - intros. - rewrite (big_rem_AC mulnA mulnC _ (z := p)) //. - f_equal. - rewrite -big_filter. - symmetry; rewrite -big_filter. - replace [seq _ <- rem p l | true] with [seq i <- l | i != p]; trivial. - rewrite rem_filter // filter_predT. - apply eq_filter. - by rewrite /predC1 /=. - Qed. - - Lemma balanced_chinese_list_mod_inner_lt (l : seq (nat * nat)) : - uniq l -> - (forall p, p \in l -> 0 < p.2) -> - forall p, - p \in l -> - \prod_(q <- l | q != p) q.2 * ((p.1 * (egcdn (\prod_(q <- l | q != p) q.2) p.2).1) %% p.2) < \prod_(q <- l) q.2. - Proof. - intros. - rewrite (prod_split1 H H1) mulnC ltn_pmul2r. - - apply ltn_pmod. - by apply H0. - - rewrite big_seq_cond. - apply prodn_cond_gt0 => i. - intros. - apply H0. - move /andP in H2. - tauto. - Qed. - - Lemma sum_list_const {T} (l : list T) (c : nat) : - \sum_(p <- l) c = (size l) * c. - Proof. - induction l. - - by rewrite big_nil /= mul0n. - - rewrite big_cons IHl /=. - lia. - Qed. - - Lemma big_sum_le_const {T:eqType} (l : list T) (F : T -> nat) (c : nat) : - (forall p, p \in l -> F p <= c) -> - \sum_(p <- l) F p <= (size l)*c. - Proof. - move=> inle. - rewrite -sum_list_const. - rewrite !big_seq. - apply leq_sum => i ini. - by apply inle. - Qed. - - Lemma big_sum_lt_const {T:eqType} (l : list T) (F : T -> nat) (c : nat) : - (forall p, p \in l -> F p < c) -> - 0 < size l -> - \sum_(p <- l) F p < (size l)*c. - Proof. - move=> inle lnnil. - destruct l. - - by rewrite !big_nil. - - rewrite !big_cons /= mulSn. - have lt1: F s < c. - { - apply inle. - apply mem_head. - } - rewrite -(ltn_add2r (size l * c)) in lt1. - eapply leq_ltn_trans; [|apply lt1]. - rewrite leq_add2l. - apply big_sum_le_const => x xin. - apply ltnW. - apply inle. - by rewrite in_cons xin orbT. - Qed. - - Lemma balanced_chinese_list_mod_lt (l : seq (nat * nat)) : - (forall p, p \in l -> 1 < p.2) -> - pairwise coprime (map snd l) -> - 0 < (size l) -> - balanced_chinese_list l < (size l) * \prod_(q <- l) q.2. - Proof. - intros. - apply big_sum_lt_const; trivial. - intros. - apply balanced_chinese_list_mod_inner_lt; trivial. - - by apply pairwise_coprime_uniq_pair. - - intros. - apply H in H3. - lia. - Qed. - - Lemma modn_add0 (m a b : nat) : - b == 0 %[mod m] -> - a %% m + b == a %[mod m]. - Proof. - intros. - move /eqP in H. - rewrite modnDml -modnDm H mod0n addn0 modn_mod //. - Qed. - - Lemma modn_mull0 (m a b : nat) : - a %% m = 0 -> - (a * b) %% m = 0. - Proof. - intros. - rewrite -(mod0n m) -modnMml H mul0n //. - Qed. - - Lemma balanced_chinese_list_mod1 (l : seq (nat * nat)) : - (forall p, p \in l -> 1 < p.2) -> - pairwise coprime (map snd l) -> - forall p, - p \in l -> - balanced_chinese_list l == p.1 %[mod p.2]. - Proof. - intros. - rewrite -modn_summ (bigD1_seq p) /= //. - - have posl: (forall p, p \in l -> 0 < p.2). - { - intros ??. - apply H in H2. - lia. - } - rewrite (eqP (balanced_chinese_list_mod_inner posl H0 H1)). - apply modn_add0. - rewrite big1_idem //. - intros. - apply modn_mull0. - rewrite -big_filter. - rewrite (perm_big_AC mulnA mulnC _ (r2:=p :: rem p [seq i0 <- l | i0 != i])). - + by rewrite big_cons modnMr. - + apply perm_to_rem. - by rewrite mem_filter H1 eq_sym H2. - - by apply pairwise_coprime_uniq_pair. - Qed. - - Lemma balanced_chinese_list_filter1 (l : seq (nat * nat)) : - balanced_chinese_list [seq x <- l | x.2 != 1] = balanced_chinese_list l. - Proof. - rewrite /balanced_chinese_list. - - move: (perm_filterC (fun a => a.2 != 1) l). - move/permPl => p1. - rewrite -(perm_big_AC addnA addnC _ _ _ p1). - rewrite big_cat /=. - - have ->: \sum_(i <-[seq x <- l | predC (fun a : nat * nat => a.2 != 1) x]) - \prod_(q <- l | q != i) q.2 * ((i.1 * (egcdn (\prod_(q <- l | q != i) q.2) i.2).1) %% i.2) = 0. - { - rewrite big_filter. - under eq_bigr. - { - move=> i /=. - case: eqVneq => //= -> _. - rewrite modn1 muln0. - over. - } - by rewrite big_const_seq iter_addn_0 mul0n. - } - rewrite addn0 !big_filter. - - apply eq_bigr => i ne1. - f_equal. - - rewrite -big_filter. - symmetry. - rewrite -big_filter -prodn_filter1 -big_filter. - f_equal. - rewrite -!filter_predI /predI. - apply eq_in_filter => x xin /=. - by rewrite andbC. - - do 4 f_equal. - rewrite -big_filter. - symmetry. - rewrite -big_filter -prodn_filter1 -big_filter. - f_equal. - rewrite -!filter_predI /predI. - apply eq_in_filter => x xin /=. - by rewrite andbC. - Qed. - - Lemma balanced_chinese_list_mod (l : seq (nat * nat)) : - (forall p, p \in l -> 0 < p.2) -> - pairwise coprime (map snd l) -> - forall p, - p \in l -> - balanced_chinese_list l == p.1 %[mod p.2]. - Proof. - intros. - case: (eqVneq p.2 1) => eqq. - { - by rewrite eqq !modn1. - } - have: balanced_chinese_list [seq x <- l | x.2 != 1] == p.1 %[mod p.2]. - - apply balanced_chinese_list_mod1. - + move=> nn. - rewrite mem_filter. - move/andP => [ne1 nnin]. - move: (H _ nnin) ne1 => /=. - destruct (nn.2) as [|[]]; lia. - + move: H0. - rewrite !pairwise_map. - apply pairwise_filter. - + by rewrite mem_filter H1 eqq. - - by rewrite balanced_chinese_list_filter1. - Qed. - - Lemma chinese_remainder_list_permutation (l l2: list (nat * nat)) : - pairwise coprime (map snd l) -> - perm_eq l l2 -> - let m := \prod_(i <- map snd l) i in - chinese_list l == chinese_list l2 %[mod m]. - Proof. - intros co_l perm. - apply chinese_remainder_list_unique; trivial. - intros. - assert (co_l2: pairwise coprime (map snd l2)). - { - rewrite (pairwise_coprime_perm (l2:=[seq i.2 | i <- l]))//. - apply perm_map. - by rewrite perm_sym. - } - move/mapP: H => [px ] in1 ->. - rewrite (eqP (chinese_remainder_list co_l in1)). - move: in1. - rewrite (perm_mem perm)=> in2. - by rewrite (eqP (chinese_remainder_list co_l2 in2)). - Qed. - - Definition Zp_reduce_r (p q : nat) (a : 'Z_(p*q)) : 'Z_q := inZp a. - Definition Zp_reduce_l (p q : nat) (a : 'Z_(p*q)) : 'Z_p := inZp a. - Definition Zp_reduce_pair (p q : nat) (a : 'Z_(p*q)) := (Zp_reduce_l a, - Zp_reduce_r a). - - Lemma modn_plus_const x a b m : - a = b %[mod m] -> - x + a = x + b %[mod m]. - Proof. - intros. - by rewrite -modnDm H modnDm. - Qed. - - Lemma Zp_reduce_r_is_morphism (p q : nat) : - 1 < p -> - 1 < q -> - rmorphism (@Zp_reduce_r p q). - Proof. - intros. - assert (1 < p*q) by lia. - assert ((Zp_trunc q).+2 %| (Zp_trunc (p * q)).+2). - { - rewrite !Zp_cast //. - apply dvdn_mull. - apply dvdnn. - } - constructor. - - intros ??. - apply val_inj; simpl. - rewrite modnDm modn_dvdm //. - apply (@modn_plus_const x (((Zp_trunc (p * q)).+2 - y) %% (Zp_trunc (p * q)).+2) - ((Zp_trunc q).+2 - y %% (Zp_trunc q).+2) (Zp_trunc q).+2). - rewrite modn_dvdm //. - destruct y. - simpl. - rewrite !Zp_cast //. - rewrite Zp_cast // in i. - clear H2 x. - rewrite modnB; try lia. - rewrite modnMl. - case (boolP (0 < m %% q)); intros; simpl. - + rewrite mul1n addn0. - assert (q - m%%q < q) by lia. - rewrite (modn_small H2) //. - + rewrite mul0n addn0. - assert (m%%q = 0) by lia. - rewrite H2 !subn0 modnn //. - - constructor. - + intros ??. - apply val_inj; simpl. - rewrite modnMm modn_dvdm // !Zp_cast //. - + apply val_inj; simpl. - rewrite modn_dvdm // !Zp_cast //. - Qed. - - Lemma Zp_reduce_l_is_morphism (p q : nat) : - 1 < p -> - 1 < q -> - rmorphism (@Zp_reduce_l p q). - Proof. - intros. - rewrite /Zp_reduce_l mulnC. - by apply Zp_reduce_r_is_morphism. - Qed. - - Lemma Zp_reduce_pair_is_morphism (p q : nat) : - 1 < p -> - 1 < q -> - rmorphism (@Zp_reduce_pair p q). - Proof. - intros. - destruct (@Zp_reduce_l_is_morphism p q H H0) as [? [? ?]]. - destruct (@Zp_reduce_r_is_morphism p q H H0) as [? [? ?]]. - constructor. - - intros ??. - rewrite /Zp_reduce_pair base base0 //. - - constructor. - + intros ??. - rewrite /Zp_reduce_pair m m0 //. - + rewrite /Zp_reduce_pair e e0 //. - Qed. - - Definition Zp_lift_pair (p q : nat) (r : 'Z_p * 'Z_q) : 'Z_(p*q) := - inZp (chinese p q r.1 r.2). - - Lemma modn_muln_l x p q : - (x %% (p * q)) %% p = x %% p. - Proof. - symmetry. - apply /eqP. - have HH: (x %% (p * q) <= x) by apply leq_mod. - rewrite eqn_mod_dvd //. - apply mul_dvdn_l with (d2 := q). - by rewrite -(eqn_mod_dvd (p * q)) // modn_mod. - Qed. - - Lemma modn_muln_r x p q : - (x %% (p * q)) %% q = x %% q. - Proof. - by rewrite mulnC modn_muln_l. - Qed. - - Lemma Zp_lift_pair_is_morphism (p q : nat) : - 1 < p -> - 1 < q -> - coprime p q -> - rmorphism (@Zp_lift_pair p q). - Proof. - intros. - assert (1 < p*q) by lia. - generalize (chinese_remainder H1); intros. - generalize (chinese_modl H1); intros. - generalize (chinese_modr H1); intros. - constructor. - - intros ??. - unfold Zp_lift_pair. - rewrite -inZp_add. - apply val_inj. - destruct x, y. - destruct s, s0, s1, s2. - rewrite /= !Zp_cast //. - apply /eqP. - rewrite H3. - apply /andP. - split; apply /eqP. - + rewrite H4. - symmetry. - rewrite -modnDm H4 modnB; [|lia|lia]. - rewrite modn_muln_l modnMr addn0. - case (boolP (0 < (chinese p q m1 m2 %% p))); simpl; intros. - * rewrite mul1n H4. - symmetry. - rewrite -modnDm !modn_mod. - apply modn_plus_const. - rewrite Zp_cast in i1; [|lia]. - rewrite modnB; [|lia|lia]. - rewrite modnn (modn_small i1) addn0. - case (boolP (0 < m1)); simpl; intros. - -- by rewrite mul1n. - -- assert (m1 = 0) by lia. - by rewrite H6 mul0n !subn0 modnn mod0n. - * assert (chinese p q m1 m2 %% p = 0) by lia. - rewrite mul0n H4. - symmetry. - rewrite -modnDm !modn_mod. - apply modn_plus_const. - rewrite Zp_cast in i1; [|lia]. - rewrite modnB; [|lia|lia]. - rewrite modnn (modn_small i1) addn0. - case (boolP (0 < m1)); simpl; intros. - -- rewrite mul1n. - rewrite H4 (modn_small i1) in H6. - lia. - -- by rewrite mul0n. - + rewrite H5. - symmetry. - rewrite -modnDm H5 modnB; [|lia|lia]. - rewrite modn_muln_r modnMl addn0. - case (boolP (0 < (chinese p q m1 m2 %% q))); simpl; intros. - * rewrite mul1n H5. - symmetry. - rewrite -modnDm !modn_mod. - apply modn_plus_const. - rewrite Zp_cast in i2; [|lia]. - rewrite modnB; [|lia|lia]. - rewrite modnn (modn_small i2) addn0. - case (boolP (0 < m2)); simpl; intros. - -- by rewrite mul1n. - -- assert (m2 = 0) by lia. - by rewrite H6 mul0n !subn0 modnn mod0n. - * assert (chinese p q m1 m2 %% q = 0) by lia. - rewrite mul0n H5. - symmetry. - rewrite -modnDm !modn_mod. - apply modn_plus_const. - rewrite Zp_cast in i2; [|lia]. - rewrite modnB; [|lia|lia]. - rewrite modnn (modn_small i2) addn0. - case (boolP (0 < m2)); simpl; intros. - -- rewrite mul1n. - rewrite H5 (modn_small i2) in H6. - lia. - -- by rewrite mul0n. - - constructor. - + intros ??. - unfold Zp_lift_pair. - rewrite -inZp_mul. - apply val_inj. - destruct x, y. - destruct s, s0, s1, s2. - rewrite /= !Zp_cast //. - apply /eqP. - rewrite H3. - apply /andP. - split; apply /eqP; symmetry. - * by rewrite -modnMm !H4 /= modnMm modn_mod. - * by rewrite -modnMm !H5 /= modnMm modn_mod. - + unfold Zp_lift_pair. - apply val_inj. - rewrite /= !Zp_cast //. - apply /eqP. - rewrite H3. - apply /andP. - split; apply /eqP. - * rewrite H4 !modn_small //. - * rewrite H5 !modn_small //. - Qed. - - Lemma right_inv_chinese (p q : nat) : - 1 < p -> - 1 < q -> - coprime p q -> - cancel (@Zp_lift_pair p q) (@Zp_reduce_pair p q). - Proof. - intros. - unfold cancel. - intros. - unfold Zp_reduce_pair, Zp_lift_pair. - destruct x. - simpl. - unfold Zp_reduce_l, Zp_reduce_r. - generalize (chinese_remainder H1); intros. - generalize (chinese_modl H1); intros. - generalize (chinese_modr H1); intros. - assert (1 < p * q) by lia. - destruct o, o0. - f_equal. - - apply val_inj. - rewrite /= !Zp_cast //. - rewrite Zp_cast // in i. - rewrite modn_muln_l H3. - by rewrite (modn_small i). - - apply val_inj. - rewrite /= !Zp_cast //. - rewrite Zp_cast // in i0. - rewrite modn_muln_r H4. - by rewrite (modn_small i0). - Qed. - - Lemma left_inv_chinese (p q : nat) : - 1 < p -> - 1 < q -> - coprime p q -> - cancel (@Zp_reduce_pair p q) (@Zp_lift_pair p q). - Proof. - intros. - unfold cancel. - intros. - unfold Zp_reduce_pair, Zp_lift_pair. - destruct x. - unfold Zp_reduce_l, Zp_reduce_r. - generalize (chinese_remainder H1); intros. - generalize (chinese_modl H1); intros. - generalize (chinese_modr H1); intros. - assert (1 < p * q) by lia. - apply val_inj. - simpl. - rewrite !Zp_cast //. - rewrite Zp_cast // in i. - replace m with (m %% (p * q)) at 3. - - apply /eqP. - rewrite H2. - apply /andP. - split. - + apply /eqP. - by rewrite H3 modn_mod. - + apply /eqP. - by rewrite H4 modn_mod. - - by rewrite (modn_small i). - Qed. - - Lemma bijective_reduce_pair (p q : nat) - (pbig: 1 < p) - (qbig: 1 < q) - (cop: coprime p q) : - bijective (@Zp_reduce_pair p q). - Proof. - eapply Bijective. - - by apply left_inv_chinese. - - by apply right_inv_chinese. - Qed. - - Lemma bijective_lift_pair (p q : nat) - (pbig: 1 < p) - (qbig: 1 < q) - (cop: coprime p q) : - bijective (@Zp_lift_pair p q). - Proof. - eapply Bijective. - - by apply right_inv_chinese. - - by apply left_inv_chinese. - Qed. - -End chinese. - -(* order of 3 mod 2^(n+2) = 2^n *) -(* show 3^(2^n) <> 1 mod 2^(n+3) *) - -From mathcomp Require Import ssrnat. - -Lemma n_n1_even j : - exists k, - j * (j + 1) = k.*2. -Proof. - assert (~~ odd(j * (j + 1))). - { - replace (j+1) with (S j) by lia. - rewrite oddM oddS. - by case: (odd j). - } - apply even_halfK in H. - exists ((j * (j + 1)) ./2). - by rewrite H. -Qed. - -Lemma modn_sub i j m : - i >= j -> - (i == j %[mod m]) = (i - j == 0 %[mod m]). -Proof. - move/eqn_mod_dvd->. - by rewrite mod0n. -Qed. - -Lemma modn_sub_iff i j m : - i >= j -> - i = j %[mod m] <-> i - j = 0 %[mod m]. -Proof. - move/modn_sub=>eqq. - split; move/eqP - ; [rewrite eqq | rewrite -eqq] - ; by move/eqP. -Qed. - - -Lemma subn_sqr_1 (x : nat) : - x^2-1 = (x + 1) * (x - 1). -Proof. - replace (x^2-1) with (x^2-1^2) by lia. - by rewrite mulnC -subn_sqr. -Qed. - -Lemma ord_odd_pow_2 j n : - (2*j+1)^(2^n.+1) = 1 %[mod 2^(n.+3)]. -Proof. - induction n. - - rewrite expn1 expnS expn1. - replace ((2 * j + 1) * (2 * j + 1)) with (4*(j*(j+1)) + 1) by lia. - destruct (n_n1_even j). - rewrite H /=. - replace (2^3) with 8 by lia. - replace (4 * (x.*2)) with (8 * x) by lia. - by rewrite -modnDm modnMr modnDmr. - - rewrite expnS (mulnC _ (2^n.+1)) expnM (expnS _ n.+3). - rewrite modn_sub_iff; [|lia]. - rewrite subn_sqr_1. - rewrite modn_sub_iff in IHn; [|lia]. - assert (exists k, - 2 * k = ((2 * j + 1) ^ 2 ^ n.+1 + 1)). - { - assert (~~ odd ((2 * j + 1) ^ 2 ^ n.+1 + 1)). - { - rewrite oddD oddX oddD. - replace (2 *j) with (j.*2) by lia. - rewrite odd_double /=. - lia. - } - apply even_halfK in H. - exists (((2 * j + 1) ^ 2 ^ n.+1 + 1)./2). - rewrite -H. - lia. - } - destruct H. - rewrite -H -mulnA -muln_modr. - replace 0 with (2*0) at 7 by lia. - rewrite -muln_modr. - f_equal. - by rewrite -modnMm IHn modnMm muln0. - Qed. - -Lemma ord_odd_pow_2' j n : - odd j -> - j^(2^n.+1) = 1 %[mod 2^(n.+3)]. -Proof. - intros. - generalize (ord_odd_pow_2 (j./2) n); intros. - generalize (odd_double_half j); intros. - replace (2 * j./2) with (j./2.*2) in H0 by lia. - rewrite H /= addnC in H1. - by rewrite -H1. -Qed. - -Lemma iotaSn0 m n : n != 0 -> - iota m n = m :: iota m.+1 n.-1. -Proof. - case: n => //=. -Qed. - -Lemma index_iotaSn0 m n : m < n -> - index_iota m n = m :: index_iota m.+1 n. -Proof. - rewrite /index_iota=> mltn. - rewrite iotaSn0; try lia. - do 2 f_equal. - lia. -Qed. - -(* -Lemma add4_pow2_mod j n : - (j + 4)^(2 ^n) = j^(2^n) + (2^n.+2)*j^(2^n-1) %[mod 2^n.+3]. -Proof. - rewrite (Pascal j 4 (2^n)) /=. - move: (@big_mkord _ 0 addn (2 ^ n).+1 predT (fun i => 'C(2 ^ n, i) * (j ^ (2 ^ n - i) * 4 ^ i)))=> /= <-. - rewrite index_iotaSn0 // big_cons. - rewrite index_iotaSn0 ?big_cons; [| lia]. - rewrite expn0 expn1 muln1 subn0 bin0 bin1 mul1n addnA. - rewrite (mulnC _ 4) mulnA. - replace (2^n*4) with (2^n.+2) by (rewrite !expnS; lia). - assert (\sum_(2 <= j0 < (2 ^ n).+1) 'C(2 ^ n, j0) * (j ^ (2 ^ n - j0) * 4 ^ j0) = 0 %[mod 2^n.+3]). - { - rewrite -modn_summ. - rewrite (eqP (_ : \sum_( _ <= _ < _ ) _ == 0)) //. - rewrite (big_nat_widenl _ 0) //. - move: (@big_mkord _ 0 addn (2 ^ n).+1 (fun i => (andb true (leq (S (S O)) i))) (fun i => ('C(2 ^ n, i) * (j ^ (2 ^ n - i) * 4 ^ i)) - %% 2 ^ n.+3)) => /= ->. - rewrite sum_nat_eq0. - apply/forallP => x. - apply/implyP => xbig. - assert (forall k, k < n -> ('C(2 ^ n, 2^k.+1) * 4 ^ 2^k.+1) %% 2 ^ n.+3 == 0)%N. - { - intros. - assert (exists q, 'C(2^n, 2^k.+1) = q * 2^(n-k.+1)). - { - admi t. - } - destruct H0. - rewrite H0. - replace 4 with (2^2) by lia. - rewrite -expnM -expnS. - replace (x0 * 2 ^ (n - k.+1) * 2 ^ 2 ^ k.+2) with - (x0 * 2^ (n - k.+1 + (2^k.+2))). - - assert (2^k.+2 >= k.+4). - { - clear H H0. - induction k; trivial. - assert (k.+4 < 2*2^k.+2) by lia. - by rewrite expnS. - } - rewrite -modnMm. - assert (exists q, (2 ^ (n - k.+1 + 2 ^ k.+2) = q * 2^n.+3)). - { - admi t. - } - destruct H2. - by rewrite H2 modnMl muln0 mod0n. - - by rewrite -mulnA expnD. - } - assert (('C(2 ^ n, x) * 4 ^ x) %% 2 ^ n.+3 == 0)%N. - { - (* 2^(n-1) | 'C(2^n, 2), ok for x = 2 *) - (* 2^n | 'C(2^n,3), ok for x = 3 *) - (* 2^(n-2) | 'C(2^n, 4), ok for x = 4 *) - (* enough to prove for x = 2^k, since sucessive values are better *) - admi t. - } - by rewrite (mulnC _ (4 ^ x)) mulnA -modnMm (eqP H0) mul0n mod0n. - } - by rewrite -modnDmr -modnDmr H !mod0n addn0. - -Admi tted. - -Lemma ord_pow_2_odd j n : - odd j -> - j ^ (2^n) = 1 %[mod 2^n.+3] -> - (j + 4)^(2^n) <> 1 %[mod 2^n.+3]. -Proof. - intros. - rewrite add4_pow2_mod -modnDm H0 modnDm addnC modn_sub_iff; [|lia]. - replace (2 ^ n.+2 * j ^ (2 ^ n - 1) + 1 - 1) with - (2 ^ n.+2 * j ^ (2 ^ n - 1)) by lia. - rewrite (expnS _ n.+2) (mulnC 2 _). - replace 0 with (2^n.+2 * 0) at 6 by lia. - rewrite -!muln_modr mod0n muln0. - apply /eqP. - rewrite muln_eq0. - apply /norP. - split; [lia |]. - by rewrite modn2 oddX H orbT. -Qed. -*) - -(* https://math.stackexchange.com/questions/459815/the-structure-of-the-group-mathbbz-2n-mathbbz *) - - -Lemma mod_mul_mul_0 a b m1 m2 : - a == 0 %[mod m1] && (b == 0 %[mod m2]) -> - a * b == 0 %[mod m1 * m2]. -Proof. - do 3 (rewrite eqn_mod_dvd; [|lia]). - rewrite !subn0. - move /andP=>[diva divb]. - by rewrite dvdn_mul. -Qed. - -Lemma mod_mul_mul_0_alt a b m1 m2 : - a = 0 %[mod m1] /\ (b = 0 %[mod m2]) -> - a * b = 0 %[mod m1 * m2]. -Proof. - move=>[/eqP? /eqP?]. - by apply/eqP/mod_mul_mul_0/andP. -Qed. - -Lemma mod_pow2_sqr_aux a b n : - b <= a -> - a = b %[mod 2^n.+1] -> - a^2 = b^2 %[mod 2^n.+2]. -Proof. - intros. - rewrite modn_sub_iff // in H0. - rewrite modn_sub_iff. - - rewrite subn_sqr. - rewrite expnS mulnC. - apply mod_mul_mul_0_alt. - split; trivial. - rewrite -modn_sub_iff // expnS in H0. - lia. - - by rewrite leq_sqr. - Qed. - -Lemma mod_pow2_sqr a b n : - a = b %[mod 2^n.+1] -> - a^2 = b^2 %[mod 2^n.+2]. -Proof. - case (boolP (b <= a)); intros. - - by apply mod_pow2_sqr_aux. - - symmetry. - symmetry in H. - apply mod_pow2_sqr_aux; lia. - Qed. - -Lemma ord_5_pow_2 n : - 5 ^ (2 ^ n) = 1 + 2^n.+2 %[mod 2^n.+3]. -Proof. - induction n. - - lia. - - rewrite expnS mulnC expnM. - apply mod_pow2_sqr in IHn. - rewrite IHn. - generalize (expnD (1 + 2^n.+2) 1 1); intros. - rewrite H !expn1. - replace ((1 + 2 ^ n.+2) * (1 + 2 ^ n.+2)) with - (1 + 2 * (2^n.+2) + (2^n.+2)*(2^n.+2)) by lia. - rewrite -expnD. - assert (2^(n.+2 + n.+2) = 0 %[mod 2^n.+4]). - { - replace (n.+2 + n.+2) with ((2 * n).+4) by lia. - rewrite !expnS -!muln_modr. - replace (2^(2*n) %% 2^n) with 0. - - rewrite muln0 mod0n //. - - rewrite mulnC expnM. - generalize (expnD (2 ^ n) 1 1); intros. - by rewrite H0 !expn1 -modnMm modnn muln0 mod0n. - } - by rewrite -modnDm H0 mod0n addn0 modn_mod (expnS _ (n.+2)). - Qed. - -Lemma add_exp_mod_p a b p : - prime p -> - (a + b)^p = a^p + b^p %[mod p]. -Proof. - move=> pprime. - rewrite expnDn. - move: (prime_gt0 pprime). - case: p pprime; [lia |]=> p pprime _. - rewrite big_ord_recr big_ord_recl /=. - rewrite bin0 binn subn0 subnn !expn0 !mul1n muln1. - rewrite addnC addnA. - rewrite -modnDmr -modn_summ. - suff/eqP->: \sum_(i < p) (('C(p.+1, bump 0 i) * (a ^ (p.+1 - bump 0 i) * b ^ bump 0 i)) %% p.+1) == 0 - by rewrite mod0n addn0 addnC. - rewrite sum_nat_eq0. - apply/forallP=> k. - apply/implyP=> _. - rewrite -modnMml (eqP (_ : 'C(p.+1, bump 0 k) == 0 %[mod p.+1])). - - by rewrite mod0n. - - rewrite (eqn_mod_dvd p.+1) // subn0 prime_dvd_bin //. - rewrite /bump. - destruct k; simpl; lia. -Qed. - -Lemma iffbP {x y : bool}: reflect (x <-> y) (x == y). -Proof. - case: x y; case - ; rewrite eqE /= /eqb /= - ; constructor - ; firstorder. -Qed. - -Lemma iffEq {x y : bool}: (x <-> y) <-> (x = y). -Proof. - case: x y; case; firstorder. - symmetry; firstorder. -Qed. - -Lemma prime_pow_dvd_aux k p n: - k <= p^n -> - forall j, - j <= n -> - (p ^ j %| k) = (p^j %| (p^n - k)). -Proof. - intros kle j jlt. - assert (p^j %| p^n). - { - by apply dvdn_exp2l. - } - apply iffEq. - split; intros. - - rewrite dvdn_sub //. - - by rewrite -(dvdn_subr (m := p^n)). -Qed. - -Lemma prime_pow_dvd j k p n : - prime p -> - ~~ (p %| j) -> - p^n.+1 %| j * k -> - p^n.+1 %| k. -Proof. - move=> pprime pndivj. - induction n. - - rewrite !expn1 Euclid_dvdM //. - case/orP => // eqq1. - by rewrite eqq1 in pndivj. - - intros. - assert (p^n.+1 %| j * k). - { - rewrite expnS in H. - by apply mul_dvdn_r in H. - } - generalize (divnK (IHn H0)); intros. - clear IHn H0. - rewrite -H1 expnS dvdn_pmul2r. - + rewrite -H1 mulnA expnS dvdn_pmul2r in H. - * rewrite Euclid_dvdM // in H. - move /orP in H. - destruct H; trivial. - by move /negP in pndivj. - * apply prime_gt1 in pprime; lia. - + apply prime_gt1 in pprime; lia. - Qed. - - -Lemma expn_boundr a b c : - 1 < a -> - a ^ b <= c -> - b <= c. -Proof. - move=> abig able. - rewrite -(@leq_exp2l a) // (leq_trans able) //. - by rewrite ltnW // ltn_expl. -Qed. - -Lemma max_prime_pow_dvd j p : - 1 < p -> - 0 < j -> - {k | (p^k %| j) && ~~ (p^k.+1 %| j)}. -Proof. - move=> pbig jnneg. - have exP: exists i : nat, [pred k | p ^ k %| j] i. - { - exists 0. - by rewrite /= expn0 dvd1n. - } - - have bounded: forall i : nat, [pred k | p ^ k %| j] i -> i <= j. - { - move=> i /=. - move/(dvdn_leq jnneg). - by apply expn_boundr. - } - - exists (ex_maxn exP bounded). - case: ex_maxnP=> i /= divi ub. - rewrite divi /=. - apply/negP. - move/ub. - by rewrite ltnn. -Qed. - -(* -Lemma prime_dvd_pow_m1_bin k p n : prime p -> k < p^n.+1 -> - ~~ (p %| 'C(p^n.+1-1, k)). -Proof. - intros. - apply /negP. - revert H0. - induction k. - - rewrite bin0. - intros ??. - apply prime_gt1 in H. - apply dvdn_leq in H1; lia. - - intros. - assert (k < p^n.+1) by lia. - specialize (IHk H1). - generalize (mul_bin_left (p^k-1) k); intros. - intros ?. - generalize (prime_gt1 H); intros. - replace (p^k-1-k) with (p^k-k.+1) in H2. - + assert (0 < k.+1) by lia. - destruct (max_prime_pow_dvd H4 H5). - assert (0< p^k -k.+1). - { - admi t. - } - destruct (max_prime_pow_dvd H4 H6). - assert (x = x0). - { - admi t. - } - move /andP in i. - move /andP in i0. - destruct i; destruct i0. - admi t. - + clear IHk H2 H3; lia. -Admi tted. -*) - -Lemma prime_power_dvd_mul_helper p k a b : - prime p -> - p ^ k %| a * b -> - exists j, (j<=k) && (p^j %| a) && (p^(k-j) %| b). -Proof. - intros. - induction k. - - exists 0. - by rewrite subn0 !expn0 !dvd1n. - - assert (p^k %| a*b). - { - rewrite expnS in H0. - by apply mul_dvdn_r in H0. - } - specialize (IHk H1). - destruct IHk. - move /andP in H2. - destruct H2. - move /andP in H2. - destruct H2. - generalize (divnK H3); intros. - generalize (divnK H4); intros. - rewrite -H5 -H6 expnS in H0. - assert (forall k, 0 < p^k). - { - intros. - rewrite expn_gt0. - apply prime_gt0 in H. - apply /orP. - tauto. - } - replace (a %/ p ^ x * p ^ x * (b %/ p ^ (k - x) * p ^ (k - x))) with - (a %/ p ^ x * (b %/ p ^ (k - x) * p ^ k)) in H0. - + rewrite mulnA in H0. - rewrite dvdn_pmul2r // in H0. - rewrite (Euclid_dvdM _ _ H) in H0. - move /orP in H0. - destruct H0. - * exists (x.+1). - apply /andP; split. - -- apply /andP; split; trivial. - rewrite -H6 expnS dvdn_pmul2r //. - -- replace (k.+1-x.+1) with (k-x); trivial. - * exists x. - apply /andP; split; trivial. - -- apply /andP; split; trivial. - lia. - -- rewrite -H5. - replace (k.+1-x) with (1 + (k - x)). - ++ rewrite expnD expn1 dvdn_pmul2r //. - ++ clear H1 H3 H4 H5 H0. - lia. - + clear H1 H3 H4 H5 H6 H0. - replace (p^k) with (p^x * p^(k-x)); try lia. - rewrite -expnD. - f_equal. - lia. - Qed. - -Lemma prime_pow_dvd_gen' j k p e n : - prime p -> - e <= n -> - (p^e %| j) -> - ~~ (p^e.+1 %| j) -> - p^n.+1 %| j * k -> - p^(n.+1-e) %| k. -Proof. - intros p_prim ele divj notdivj divjk. - generalize (divnK divj); intros. - rewrite -H in divjk. - rewrite mulnC mulnA in divjk. - replace (p^n.+1) with (p^(n.+1-e) * p^e) in divjk. - - rewrite dvdn_pmul2r in divjk. - assert (~~ (p %| (j %/ p ^ e))). - { - apply /negP. - intros ?. - apply prime_gt0 in p_prim. - assert (0 < p^e). - { - rewrite expn_gt0. - apply /orP. - tauto. - } - rewrite -(dvdn_pmul2r H1) in H0. - rewrite divnK // in H0. - rewrite expnS in notdivj. - move /negP in notdivj. - tauto. - } - + replace (n.+1-e) with ((n-e).+1) in divjk. - * rewrite mulnC in divjk. - generalize (prime_pow_dvd p_prim H0 divjk); intros. - replace (n.+1-e) with ((n-e).+1); trivial. - clear divj notdivj H H0 divjk H1; lia. - * clear divj notdivj H H0 divjk; lia. - + rewrite expn_gt0. - apply prime_gt0 in p_prim. - apply /orP. - tauto. - - rewrite -expnD. - f_equal. - clear divj notdivj H divjk; lia. - Qed. - -Lemma prime_pow_dvd_gen j k p e n : - prime p -> - e <= n.+1 -> - (p^e %| j) -> - ~~ (p^e.+1 %| j) -> - p^n.+1 %| j * k -> - p^(n.+1-e) %| k. -Proof. - intros p_prim ele divj notdivj divjk. - case (boolP (e <= n)). - - intros. - by apply prime_pow_dvd_gen' with (j := j). - - intros. - assert (e = n.+1) by lia. - by rewrite H subnn expn0 dvd1n. - Qed. - -(* -Lemma prime_power_dvd_mul p k a b : - prime p -> - p ^ k %| a * b = [exists j:(ordinal (k.-1.+1)), (p^j %| a) && (p^(k-j) %| b)]. -Proof. -(* move=> pprime. - case: (eqVneq a 0)=> [-> | neqq]. - - admi t. - - case: (@max_prime_pow_dvd a p). - + by apply prime_gt1. - + lia. - + move => x /andP-[pdiva pndiva]. - have: p^(k-x) %| b = (p ^ k %| a * b). - { - repeat case: dvdnP => //. - - move=> HH1 [j eqq1]; subst. - elim HH1. - move: pdiva. - move/dvdnP=> [jj eqq2]; subst. - exists (jj * j). - rewrite -!mulnA. - f_equal. - rewrite mulnC -!mulnA. - f_equal. - rewrite -expnD. - f_equal. - rewrite subnK //. - - - - } - - case: dvdnP. - apply/iffEq. - split=> HH. - * have pdvib: p^(x-k) %| b. - { - Search (_ * _ %| _ * _). - - move/andP. - - case: dvdnP=> HH; symmetry. - - destruct HH as [j eqq]. - - - - - induction k => /=. - - rewrite expn0 dvd1n. - symmetry. - apply/existsP. - simpl. - exists ord0. - by rewrite /= expn0 !dvd1n. - - rewrite expnS. - apply/iffEq. - split=> HH. - + have: p ^ k %| a * b by apply mul_dvdn_r in HH. - rewrite IHk. - move/existsP=>[j ]/andP-[diva divb]. - apply/existsP. - case_eq (dvdn (p^j.+1) a) => eqq1. - * have ordpf: j.+1 < k.+1. - { - admi t. - } - exists (Ordinal ordpf). - by rewrite /= eqq1 /= subSS. - * have ordpf: j < k.+1. - { - admi t. - } - exists (Ordinal ordpf). - rewrite /= diva /=. - - - dvdn_pmul2r: forall [p d m : nat], 0 < p -> (d * p %| m * p) = (d %| m) - dvdn_pmul2l: forall [p d m : nat], 0 < p -> (p * d %| p * m) = (d %| m) - *) - -Admi tted. -*) - -Lemma prime_pow_dvd_bin_full j k p n : - prime p -> 0 < k < p^n -> 0 < n -> - p^j %| k -> - ~~ (p^j.+1 %| k) -> - p^(n-j) %| 'C(p^n, k). -Proof. - intros. - have HH: (p^n %| k * 'C(p^n,k)). - { - destruct k; trivial. - by rewrite -mul_bin_diag dvdn_mulr. - } - assert (j < n). - { - assert (0 < k) by lia. - generalize (dvdn_leq H4 H2); intros. - clear H2 H3 HH. - apply ltn_pexp2l with (m := p); try lia. - by apply prime_gt0. - } - destruct n; trivial. - apply prime_pow_dvd_gen with (j := k); trivial. - lia. -Qed. - -Lemma prime_dvd_pow_bin k p n : - prime p -> - 0 < k < p -> - p^n.+1 %| 'C(p^n.+1, k). -Proof. - intros. - generalize (mul_bin_down (p^n.+1) k); intros. - have HH: (~~ (p %| p^n.+1-k)). - { - rewrite dvdn_subr. - - apply/negP. - move/dvdn_leq. - lia. - - apply (@leq_trans (k ^ n.+1)). - + rewrite -[k]expn1 leq_pexp2l //. - lia. - + rewrite leq_exp2r //. - lia. - - by rewrite dvdn_exp. - } - assert (p^n.+1 %| (p^n.+1-k) * 'C(p^n.+1,k)). - { - rewrite -H1. - apply dvdn_mulr. - apply dvdnn. - } - rewrite Gauss_dvdr // in H2. - rewrite coprimeXl //. - by rewrite prime_coprime. -Qed. - -(* -Lemma add_exp_mod_exp_p p k : - prime p -> - odd p -> - (1 + p)^(p^k) = 1 %[mod p^k.+1]. -Proof. - intros. - rewrite expnDn. - rewrite big_ord_recl /= bin0 exp1n !mul1n expn0. - rewrite -modnDmr -modn_summ. - suff/eqP-> : \sum_(i < p ^ k) 'C(p ^ k, bump 0 i) * (1 ^ (p ^ k - bump 0 i) * p ^ bump 0 i) %% (p ^ k.+1) == 0 - by rewrite mod0n addn0. - - rewrite sum_nat_eq0. - apply/forallP=> i. - apply/implyP=> _. - rewrite exp1n mul1n. - rewrite /bump /=. - - - - (* - rewrite bin_ffactd. - *) - -Admi tted. - - - -Lemma ord_p1_pow_p p n : - prime p -> - odd p -> - (1 + p)^(p^n) = 1 + p^n.+1 %[mod p^n.+2]. -Proof. - intros. - induction n. - - by rewrite expn0 !expn1. - - rewrite expnS expnS. - apply (f_equal (fun z => z^p)) in IHn. - Admi tted. - - *) - -Lemma ord_5_pow_2_neq n : - 5^(2^n) <> 1 %[mod 2^n.+3]. -Proof. - rewrite ord_5_pow_2 !expnS !modn_small; lia. -Qed. - -Lemma ord_5_pow_2_neq_m1 n : - 5^(2^n) <> 2^n.+3-1 %[mod 2^n.+3]. -Proof. - rewrite ord_5_pow_2 !expnS; lia. -Qed. - -Lemma ord_pow_gcd b e1 e2 n : - b^e1 = 1 %[mod n] -> - b^e2 = 1 %[mod n] -> - b^(gcdn e1 e2) = 1 %[mod n]. -Proof. - intros. - destruct e2. - - by rewrite gcdn0. - - assert (0 < e2.+1) by lia. - destruct (egcdnP e1 H1). - apply (f_equal (fun z => z^kn %% n)) in H. - rewrite !modnXm -expnM mulnC exp1n in H. - apply (f_equal (fun z => z^km %% n)) in H0. - by rewrite !modnXm -expnM mulnC e expnD exp1n -modnMm H modnMm mul1n gcdnC in H0. - Qed. - -From mathcomp Require Import poly zmodp. -Local Open Scope ring_scope. - -Lemma ord_pow2' (n : nat) (b : 'Z_(2^n.+3)): - b^+(2^n.+1) = 1 :> 'Z_(2^n.+3) -> - b^+(2^n) <> 1 :> 'Z_(2^n.+3) -> - (2^(S n)).-primitive_root b. -Proof. - intros. - by apply @two_pow_prim_root_alt. -Qed. - -Lemma zp_m1_neq1 (n : nat) : - n > 2 -> - -1 <> 1 :> 'Z_n. -Proof. - intros. - injection; unfold Zp_trunc; simpl. - replace (n.-2.+2) with n by lia. - have /modn_small->: (1 < n)%N by lia. - have /modn_small->: (n-1 < n)%N by lia. - lia. -Qed. - -Lemma unit_pow_2_Zp (n : nat) (b : 'Z_(2^n.+1)) : - b \is a unit <-> - odd b. -Proof. - have/(unitZpE b): (2^n.+1 > 1). - { - rewrite !expnS; lia. - } - rewrite (_ : (b%:R) = b) ?natr_Zp // => ->. - rewrite -coprimen2 coprime_sym coprime_pexpr; lia. -Qed. - -Lemma unit_pow_2_Zp' (n : nat) (b : {unit 'Z_(2^n.+1)}) : - odd (val b). -Proof. - by rewrite -unit_pow_2_Zp ?(valP b). -Qed. - - -Lemma ord_5_pow_2_Zp' n : - inZp (5 ^ (2^n)) = inZp (1 + 2^n.+2) :> 'Z_(2^n.+3). -Proof. - generalize (ord_5_pow_2 n); intros. - rewrite /inZp. - apply /eqP. - rewrite /eq_op /=. - rewrite Zp_cast; [| rewrite !expnS; lia]. - by apply /eqP. -Qed. - -Lemma ord_5_pow_2_Zp n : - inZp 5 ^+ (2^n) = inZp (1 + 2^n.+2) :> 'Z_(2^n.+3). -Proof. - rewrite -ord_5_pow_2_Zp'. - by rewrite inZp_exp. -Qed. - -Lemma ord_5_pow_2_Zp_1 n : - inZp 5 ^+ (2^n.+1) = 1 :> 'Z_(2^n.+3). -Proof. - assert (odd5:odd 5) by by []. - move: (ord_odd_pow_2' n odd5)=> b2n1_1. - rewrite -inZp_exp. - apply: val_inj => /=. - rewrite Zp_cast; try lia. - rewrite !expnS; lia. -Qed. - -Lemma ord_3_pow_2_Zp_1 n : - inZp 3 ^+ (2^n.+1) = 1 :> 'Z_(2^n.+3). -Proof. - assert (odd3:odd 3) by by []. - move: (ord_odd_pow_2' n odd3)=> b2n1_1. - rewrite -inZp_exp. - apply: val_inj => /=. - rewrite Zp_cast; try lia. - rewrite !expnS; lia. -Qed. - -Lemma primitive_5_pow2 n : - let b5 : 'Z_(2^n.+3) := inZp 5 in - (2^n.+1).-primitive_root b5. -Proof. - apply ord_pow2'. - - apply ord_5_pow_2_Zp_1. - - rewrite ord_5_pow_2_Zp. - intros ?. - apply (f_equal val) in H. - simpl in H. - rewrite Zp_cast in H; [|rewrite !expnS; lia]. - rewrite modn_small in H; [|rewrite !expnS; lia]. - rewrite modn_small in H; [|rewrite !expnS; lia]. - lia. -Qed. - -Lemma m1_neq_pow5_mod2n (n : nat) : - let b5 : 'Z_(2^n.+3) := inZp 5 in - not (exists k, b5^+k = -1). -Proof. - generalize (primitive_5_pow2 n); intros. - simpl in H. - generalize (two_pow_prim_root_m1_alt b5 n H); intros. - apply H0. - - apply zp_m1_neq1. - rewrite !expnS; lia. - - rewrite ord_5_pow_2_Zp. - rewrite /opp /= /Zp_opp. - intros ?. - apply (f_equal val) in H1. - simpl in H1. - rewrite Zp_cast in H1; [|rewrite !expnS; lia]. - rewrite modn_small in H1; [|rewrite !expnS; lia]. - rewrite modn_small in H1; [|rewrite !expnS; lia]. - rewrite modn_small in H1; [|rewrite !expnS; lia]. - rewrite !expnS in H1; lia. -Qed. - -Lemma m1_neq_pow5_mod2n_gen (n : nat) : - let b5 : 'Z_(2^n.+2) := inZp 5 in - not (exists k, b5^+k = -1). -Proof. - destruct n. - - simpl. - set b5 : 'Z_(2^0.+2) := inZp 5. - unfold not. - intros. - destruct H. - assert (b5 = 1). - { - apply val_inj; simpl. - rewrite /Zp_trunc /=. - lia. - } - rewrite H0 expr1n in H. - rewrite /opp /= /Zp_opp in H. - apply (f_equal val) in H. - simpl in H. - rewrite Zp_cast in H; lia. - - apply m1_neq_pow5_mod2n. -Qed. - - Lemma pow_3_5_pow_2 n : - 3^(2^n.+1) = 5^(2^n.+1) %[mod 2^n.+4]. - Proof. - induction n. - - lia. - - symmetry. - rewrite modn_sub_iff; [|rewrite leq_exp2r; lia]. - rewrite expnS !(mulnC 2%N _) !expnM subn_sqr. - symmetry in IHn. - rewrite modn_sub_iff in IHn; [|rewrite leq_exp2r; lia]. - rewrite (expnS _ n.+4) (mulnC 2%N _). - rewrite mod_mul_mul_0_alt; trivial. - split; trivial. - rewrite modn2 mod0n oddD !oddX. - lia. - Qed. - - Lemma ord_3_pow_2_neq n : - 3^(2^n) <> 1 %[mod 2^n.+3]. - Proof. - destruct n. - - lia. - - rewrite pow_3_5_pow_2. - apply ord_5_pow_2_neq. - Qed. - - Lemma ord_3_pow_2_neq_m1 n : - 3^(2^n) <> 2^n.+3-1 %[mod 2^n.+3]. - Proof. - destruct n. - - lia. - - rewrite pow_3_5_pow_2. - apply ord_5_pow_2_neq_m1. - Qed. - - Lemma primitive_3_pow2 n : - let b3 : 'Z_(2^n.+3) := inZp 3 in - (2^n.+1).-primitive_root b3. - Proof. - apply ord_pow2'. - - apply ord_3_pow_2_Zp_1. - - generalize (@ord_3_pow_2_neq n); intros. - unfold one; simpl. - unfold Zp1. - intros ?. - rewrite -inZp_exp in H0. - apply (f_equal val) in H0. - simpl in H0. - rewrite Zp_cast in H0; [|rewrite !expnS; lia]. - tauto. - Qed. - - Lemma m1_neq_pow3_mod2n (n : nat) : - let b3 : 'Z_(2^n.+3) := inZp 3 in - not (exists k, b3^+k = -1). -Proof. - generalize (primitive_3_pow2 n); intros. - simpl in H. - generalize (two_pow_prim_root_m1_alt b3 n H); intros. - apply H0. - - apply zp_m1_neq1. - rewrite !expnS; lia. - - generalize (@ord_3_pow_2_neq_m1 n); intros. - unfold opp; simpl. - unfold Zp_opp. - intros ?. - unfold b3 in H2. - rewrite -inZp_exp in H2. - apply (f_equal val) in H2. - simpl in H2. - rewrite Zp_cast in H2; [|rewrite !expnS; lia]. - rewrite H2 in H1. - clear H0 H2. - rewrite modn_small in H1. - + rewrite modn_small in H1; [|rewrite !expnS; lia]. - rewrite modn_small in H1; [|rewrite !expnS; lia]. - tauto. - + rewrite modn_small; [|rewrite !expnS; lia]. - lia. -Qed. - -From mathcomp Require Import finset eqtype finalg. -From mathcomp Require Import fingroup.quotient. -Section two_pow_units. - - Import GroupScope. -Lemma ord_unit_pow_2_Zp (n : nat) (b : {unit 'Z_(2^n.+3)}) : - b ^+ (2^n.+1) = 1. -Proof. - move: (unit_pow_2_Zp' b)=> bodd. - move: (ord_odd_pow_2' n bodd)=> b2n1_1. - move: (unit_Zp_expg b (2^n.+1)). - rewrite /inZp. - move/(f_equal val)=> /=. - rewrite {3}Zp_cast; [| rewrite !expnS; lia]. - rewrite b2n1_1 => eqq. - apply/eqP. - rewrite /eq_op /= /eq_op /= eqq. - rewrite Zp_cast; [| rewrite !expnS; lia]. - rewrite modn_small // !expnS; lia. -Qed. - -Lemma ord_unit_pow_2_Zp' (n : nat) (b : {unit 'Z_(2^n.+3)}) : - #[b] %| (2^n.+1)%N. -Proof. - rewrite order_dvdn. - apply /eqP. - apply ord_unit_pow_2_Zp. -Qed. - -Lemma dvdn_prime_power x p n : - prime p -> - x %| p^n.+1 -> - ~ x %| p^n -> - (x = p^n.+1)%N. -Proof. - intros p_prime x_n1 x_n. - generalize (prime_gt1 p_prime); intros pgt. - move /dvdn_pfactor in x_n1. - destruct (x_n1 p_prime). - rewrite H0 (dvdn_Pexp2l x0 n pgt) in x_n. - assert (x0 = n.+1) by lia. - by rewrite H1 in H0. -Qed. - -Lemma ord_unit_pow_2_Zp_max (n : nat) (b : {unit 'Z_(2^n.+3)}) : - b ^+ (2^n) <> 1 -> - #[b] = (2^n.+1)%N. -Proof. - intros. - generalize (ord_unit_pow_2_Zp' b); intros. - assert (~ #[b] %| 2^n). - { - intros ?. - rewrite order_dvdn in H1. - by move /eqP in H1. - } - by apply dvdn_prime_power. -Qed. - -Lemma card_units_pow_2_Zp (n : nat) : - #|units_Zp (2^n.+1)| = (2^n)%N. -Proof. - rewrite card_units_Zp; try lia. - rewrite totient_pfactor // /=. - lia. -Qed. - -Lemma unit_Zp_gens_ord (n : nat) (a b : {unit 'Z_n}) : - #|<[a]>%G :&: <[b]>%G| = 1%N -> - #|<[a]> * <[b]>| = (#[a] * #[b])%N. -Proof. - intros. - by rewrite mul_cardG H muln1. -Qed. - -Lemma unit_pow_2_Zp_gens_ord (n : nat) (a b : {unit 'Z_(2^n.+2)}) : - #[a] = 2%N -> - #[b] = (2^n)%N -> - #|<[a]>%G :&: <[b]>%G| = 1%N -> - #|<[a]> * <[b]>| = (2^n.+1)%N. -Proof. - intros. - rewrite (unit_Zp_gens_ord H1) H H0; trivial. - by rewrite expnS. -Qed. - -Lemma ord2_setI_G1 (n : nat) (a b : {unit 'Z_(2^n)}) : - #[a] = 2%N -> - (a \notin <[b]>) -> - <[a]>%G :&: <[b]>%G = 1. -Proof. - intros. - have ->: (<[a]> :&: <[b]> = [set 1]). - { - rewrite (cycle2g H) setIUl. - have /eqP->: ([set a] :&: <[b]> == set0). - { - by rewrite setI_eq0 disjoints1. - } - by rewrite setU0 -set1gE setI1g. - } - easy. -Qed. - -Lemma ord2_setI (n : nat) (a b : {unit 'Z_(2^n)}) : - #[a] = 2%N -> - (a \notin <[b]>) -> - #|<[a]>%G :&: <[b]>%G| = 1%N. -Proof. - intros. - rewrite ord2_setI_G1; trivial. - by rewrite cards1. -Qed. - -Lemma unit_pow_2_Zp_gens (n : nat) (a b : {unit 'Z_(2^n.+2)}) : - #[a] = 2%N -> - #[b] = (2^n)%N -> - a \notin <[b]> -> - <[a]> <*> <[b]> = [group of (units_Zp (2^n.+2)%N)]. -Proof. - intros. - generalize (subsetT (<[a]> * <[b]>)%G); intros. - apply index1g; trivial. - rewrite -(divgS H2) (card_units_pow_2_Zp n.+1) joinGE /= norm_joinEr /=. - - rewrite unit_pow_2_Zp_gens_ord //. - + rewrite divnn !expnS; lia. - + by apply ord2_setI. - - apply cents_norm. - eapply subset_trans. - apply subsetT. - apply sub_abelian_cent. - + apply units_Zp_abelian. - + apply subsetT. -Qed. - -Lemma unit_3_pow_2_Zp (n : nat) : - (3 : 'Z_(2^n.+1)) \is a unit. -Proof. - rewrite unitZpE. - - rewrite coprimeXl //. - - rewrite !expnS; lia. -Qed. - -Lemma unit_5_pow_2_Zp (n : nat) : - (5 : 'Z_(2^n.+1)) \is a unit. -Proof. - rewrite unitZpE. - - rewrite coprimeXl //. - - rewrite !expnS; lia. -Qed. - -Lemma unit_odd_pow_2_Zp (j n : nat): - odd j -> - (inZp j : 'Z_(2^n.+1)) \is a unit. -Proof. - intros. - rewrite unit_pow_2_Zp /= expnS Zp_cast; [|lia]. - rewrite odd_mod //. - replace (2 * 2^n)%N with ((2^n).*2) by lia. - by rewrite odd_double. -Qed. - -Lemma unit_3_pow_2_Zp' (n : nat) : - (3 : 'Z_(2^n.+1)) \is a unit. -Proof. - apply unit_odd_pow_2_Zp. - rewrite /= expnS Zp_cast; lia. -Qed. - -Lemma m1_not_in_unit_3_pow (n : nat) : - FinRing.unit 'Z_(2 ^ n.+3) (unitrN1 (Zp_finUnitRingType (Zp_trunc (2 ^ n.+3)))) - \notin <[FinRing.unit 'Z_(2 ^ n.+3) (unit_3_pow_2_Zp n.+2)]>. -Proof. - have small1: 1 < 2 ^ n.+3 by (rewrite !expnS; lia). - have small2: 2 < 2 ^ n.+3 by (rewrite !expnS; lia). - have small3: 3 < 2 ^ n.+3 by (rewrite !expnS; lia). - have nexist := @m1_neq_pow3_mod2n n. - apply/negP. - move/cyclePmin => [x xlt]. - move/(f_equal (fun (z : {unit 'Z_(2^n.+3)}) => val z)). - rewrite /= unit_Zp_expg /= {2 3 4 5 6}Zp_cast // !modn_small // /inZp. - move/(f_equal val) => /=. - rewrite !Zp_cast // modn_small; [| rewrite !expnS; lia]. - rewrite modn_small // => pow3m1. - apply nexist. - exists x. - rewrite /opp /= /Zp_opp {2}Zp_cast // -inZp_exp. - apply val_inj. - by rewrite /= !Zp_cast // -pow3m1 !modn_small //; rewrite !expnS; lia. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_3 (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in - <[um1]> <*> <[u3]> = [group of (units_Zp (2^n.+3)%N)]. -Proof. - have small3: 3 < 2 ^ n.+3 by (rewrite !expnS; lia). - have small1: 1 < 2 ^ n.+3 by lia. - have small2: 2 < 2 ^ n.+3 by lia. - apply unit_pow_2_Zp_gens. - - apply nt_prime_order; trivial. - + apply val_inj. - by rewrite /= mulrNN mulr1. - + apply /eqP. - move/(f_equal FinRing.uval). - simpl. - by apply (zp_m1_neq1 small2). - - apply ord_unit_pow_2_Zp_max. - generalize (@ord_3_pow_2_neq n); intros. - move/(f_equal (fun (z : {unit 'Z_(2^n.+3)}) => val z)). - rewrite unit_Zp_expg /= {2 3 4 5 6}Zp_cast // !modn_small // /inZp. - move/(f_equal val) => /=. - rewrite !Zp_cast //. - - apply m1_not_in_unit_3_pow. -Qed. - - -Lemma unit_Z4_gens_m1 : - let um1 := FinRing.unit 'Z_4 (unitrN1 _) in - <[um1]> = [group of (units_Zp 4)]. -Proof. - intros. - generalize (subsetT (<[um1]>)); intros. - apply index1g; trivial. - rewrite -(divgS H) (card_units_pow_2_Zp 1). - assert (#[um1] = 2%N). - { - apply nt_prime_order; trivial. - apply val_inj. - by rewrite /= mulrNN mulr1. - } - unfold order in H0. - rewrite H0. - lia. -Qed. - -Lemma unit_Z2 : - <[1]> = [group of (units_Zp 2)]. -Proof. - generalize (card_units_pow_2_Zp 0); intros. - rewrite expn0 expn1 in H. - apply card1_trivg in H. - rewrite H. - apply /eqP. - apply cycle_eq1. -Qed. - -Lemma unit_Z2_alt : - let um1 := FinRing.unit 'Z_2 (unitrN1 _) in - <[um1]> = [group of (units_Zp 2)]. -Proof. - rewrite -unit_Z2. - intros. - f_equal. - unfold um1. - apply val_inj; simpl. - apply val_inj; simpl. - rewrite Zp_cast; lia. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_3_gen (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+2) (unit_3_pow_2_Zp n.+1) in - <[um1]> <*> <[u3]> = [group of (units_Zp (2^n.+2)%N)]. -Proof. - destruct n. - - intros. - rewrite -unit_Z4_gens_m1. - assert (<[u3]> \subset <[um1]>). - { - rewrite cycle_subG. - assert (u3 = um1). - { - unfold u3, um1. - apply val_inj; simpl. - apply val_inj; simpl. - rewrite Zp_cast; lia. - } - rewrite H. - apply cycle_id. - } - move /joing_idPl in H. - rewrite H. - assert (um1 = FinRing.unit 'Z_4 (unitrN1 (Zp_finUnitRingType (Zp_trunc 4)))). - { - by apply val_inj; simpl. - } - by rewrite H0. - - apply unit_pow_2_Zp_gens_m1_3. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_3_gen_gen (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+1) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+1) (unit_3_pow_2_Zp n) in - <[um1]> <*> <[u3]> = [group of (units_Zp (2^n.+1)%N)]. -Proof. - destruct n. - - intros. - rewrite -unit_Z2_alt. - assert (<[u3]> \subset <[um1]>). - { - rewrite cycle_subG. - assert (u3 = um1). - { - unfold u3, um1. - apply val_inj; simpl. - apply val_inj; simpl. - rewrite Zp_cast; lia. - } - rewrite H. - apply cycle_id. - } - move /joing_idPl in H. - rewrite H. - assert (um1 = FinRing.unit 'Z_2 (unitrN1 (Zp_finUnitRingType (Zp_trunc 2)))). - { - by apply val_inj; simpl. - } - by rewrite H0. - - apply unit_pow_2_Zp_gens_m1_3_gen. -Qed. - -Lemma m1_not_in_unit_5_pow (n : nat) : - FinRing.unit 'Z_(2 ^ n.+3) (unitrN1 (Zp_finUnitRingType (Zp_trunc (2 ^ n.+3)))) - \notin <[FinRing.unit 'Z_(2 ^ n.+3) (unit_5_pow_2_Zp n.+2)]>. -Proof. - have small5: 5 < 2 ^ n.+3 by (rewrite !expnS; lia). - have small1: 1 < 2 ^ n.+3 by lia. - have small2: 2 < 2 ^ n.+3 by lia. - have small3: 3 < 2 ^ n.+3 by lia. - have small4: 4 < 2 ^ n.+3 by lia. - generalize (@m1_neq_pow5_mod2n n); intros. - apply/negP. - move/cyclePmin => [x xlt]. - move/(f_equal (fun (z : {unit 'Z_(2^n.+3)}) => val z)). - rewrite /= unit_Zp_expg /= {2 3 4 5 6 7 8 9 10}Zp_cast // !modn_small // /inZp. - move/(f_equal val) => /=. - rewrite !Zp_cast // modn_small; [| rewrite !expnS; lia]. - rewrite modn_small // => HH. - apply H. - exists x. - rewrite /opp /= /Zp_opp {2}Zp_cast // -inZp_exp. - apply val_inj. - by rewrite /= !Zp_cast // -HH !modn_small //; rewrite !expnS; lia. -Qed. - -Lemma m1_not_in_unit_5_pow_gen (n : nat) : - FinRing.unit 'Z_(2 ^ n.+2) (unitrN1 (Zp_finUnitRingType (Zp_trunc (2 ^ n.+2)))) - \notin <[FinRing.unit 'Z_(2 ^ n.+2) (unit_5_pow_2_Zp n.+1)]>. -Proof. - destruct n. - - generalize (@m1_neq_pow5_mod2n_gen 0); intros. - apply /negP. - move/cyclePmin => [x xlt]. - move/(f_equal (fun (z : {unit 'Z_(2^0.+2)}) => val z)). - rewrite /= unit_Zp_expg /= {2 3 4 5 6 7 8 9 10}Zp_cast //. - have small3: 3 < 2 ^ 0.+2 by (rewrite !expnS; lia). - have small1: 1 < 2 ^ 0.+2 by lia. - have small2: 2 < 2 ^ 0.+2 by lia. - rewrite (modn_small small3) (modn_small small1) // /inZp. - move/(f_equal val) => /=. - rewrite !Zp_cast //. - rewrite (modn_small small1) exp1n. - rewrite !modn_small; try lia. - - apply m1_not_in_unit_5_pow. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_5 (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in - let u5 := FinRing.unit 'Z_(2^n.+3) (unit_5_pow_2_Zp n.+2) in - <[um1]> <*> <[u5]> = [group of (units_Zp (2^n.+3)%N)]. -Proof. - have small5: 5 < 2 ^ n.+3 by (rewrite !expnS; lia). - have small1: 1 < 2 ^ n.+3 by lia. - have small2: 2 < 2 ^ n.+3 by lia. - have small3: 3 < 2 ^ n.+3 by lia. - have small4: 4 < 2 ^ n.+3 by lia. - apply unit_pow_2_Zp_gens. - - apply nt_prime_order; trivial. - + apply val_inj. - by rewrite /= mulrNN mulr1. - + apply /eqP. - move/(f_equal FinRing.uval). - simpl. - by apply (zp_m1_neq1 small2). - - apply ord_unit_pow_2_Zp_max. - generalize (@ord_5_pow_2_neq n); intros. - move/(f_equal (fun (z : {unit 'Z_(2^n.+3)}) => val z)). - rewrite unit_Zp_expg /= {2 3 4 5 6 7 8 9 10}Zp_cast // !modn_small // /inZp. - move/(f_equal val) => /=. - rewrite !Zp_cast //. - - apply m1_not_in_unit_5_pow. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_3_alt (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in - <[um1]> * <[u3]> = [group of (units_Zp (2^n.+3)%N)]. -Proof. - rewrite <- unit_pow_2_Zp_gens_m1_3. - symmetry. - apply comm_joingE. - apply centC. - apply cents_cycle. - apply val_inj. - now rewrite /mulg /FinRing.unit_mul /= /mul /= Zp_mulC. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_3_alt_gen (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+2) (unit_3_pow_2_Zp n.+1) in - <[um1]> * <[u3]> = [group of (units_Zp (2^n.+2)%N)]. -Proof. - rewrite <- unit_pow_2_Zp_gens_m1_3_gen. - symmetry. - apply comm_joingE. - apply centC. - apply cents_cycle. - apply val_inj. - now rewrite /mulg /FinRing.unit_mul /= /mul /= Zp_mulC. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_5_alt (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in - let u5 := FinRing.unit 'Z_(2^n.+3) (unit_5_pow_2_Zp n.+2) in - <[um1]> * <[u5]> = [group of (units_Zp (2^n.+3)%N)]. -Proof. - rewrite <- unit_pow_2_Zp_gens_m1_5. - symmetry. - apply comm_joingE. - apply centC. - apply cents_cycle. - apply val_inj. - now rewrite /mulg /FinRing.unit_mul /= /mul /= Zp_mulC. -Qed. - -Lemma unit_pow_2_2_alt : - let um1 := FinRing.unit 'Z_(2^2) (unitrN1 _) in - let u5 := FinRing.unit 'Z_(2^2) (unit_5_pow_2_Zp 1) in - <[um1]> * <[u5]> = [group of (units_Zp (2^2))]. -Proof. - intros. - generalize unit_Z4_gens_m1; intros. - assert (u5 = 1). - { - unfold u5. - apply val_inj; simpl. - unfold Zp_trunc. - apply val_inj; simpl. - now rewrite modn_small. - } - generalize (cycle_eq1 u5); intros. - move /eqP in H0. - rewrite -H1 in H0. - move /eqP in H0. - rewrite H0. - rewrite -comm_joingE. - - rewrite joingG1. - rewrite -H. - unfold um1. - by rewrite /Zp_trunc /=. - - apply commute1. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_5_alt_gen (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in - let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in - <[um1]> * <[u5]> = [group of (units_Zp (2^n.+2)%N)]. -Proof. - destruct n. - - apply unit_pow_2_2_alt. - - apply unit_pow_2_Zp_gens_m1_5_alt. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_3_quo (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in - <[u3]>/<[um1]> = [group of (units_Zp (2^n.+3)%N)]/<[um1]>. -Proof. - intros. - rewrite - quotientMidl. - by rewrite unit_pow_2_Zp_gens_m1_3_alt. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_3_quo_alt (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in - <[um1]>/<[u3]> = [group of (units_Zp (2^n.+3)%N)]/<[u3]>. -Proof. - intros. - rewrite - quotientMidr. - by rewrite unit_pow_2_Zp_gens_m1_3_alt. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_5_quo (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in - let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in - <[u5]>/<[um1]> = [group of (units_Zp (2^n.+2)%N)]/<[um1]>. -Proof. - intros. - rewrite - quotientMidl. - by rewrite unit_pow_2_Zp_gens_m1_5_alt_gen. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_5_quo_alt (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in - let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in - <[um1]>/<[u5]> = [group of (units_Zp (2^n.+2)%N)]/<[u5]>. -Proof. - intros. - rewrite - quotientMidr. - by rewrite unit_pow_2_Zp_gens_m1_5_alt_gen. -Qed. - -Lemma quotient_isog_abelian (gT : finGroupType) (A H G : {group gT}) : - abelian A -> - H \subset A -> - G \subset A -> - H :&: G = 1 -> - morphism.isog G (G / H). -Proof. - intros. - apply quotient_isog; trivial. - apply cents_norm. - by apply (sub_abelian_cent2 H0). -Qed. - -Lemma quotient_isog_unit_Zp (p : nat) (H G : {group {unit 'Z_p}}) : - H :&: G = 1 -> - morphism.isog G (G / H). -Proof. - intros. - apply quotient_isog; trivial. - apply cents_norm. - eapply subset_trans. - apply subsetT. - apply sub_abelian_cent. - + apply units_Zp_abelian. - + apply subsetT. -Qed. - -Lemma unitrN1_card_2 (n : nat) : - #[FinRing.unit 'Z_(2 ^ n.+2) (unitrN1 _)] = 2%N. -Proof. - rewrite -(expn1 2). - apply dvdn_prime_power; trivial. - - rewrite order_dvdn expn1. - apply /eqP. - apply val_inj. - by rewrite /= mulrNN mulr1. - - rewrite expn0 /not. - intros. - rewrite order_dvdn expg1 expn1 in H. - move /eqP in H. - rewrite /oneg /= in H. - apply (f_equal val) in H. - rewrite /= /opp /one /= in H. - apply (f_equal val) in H. - rewrite /Zp_trunc /= in H. - have small2: 2 < 2 ^ n.+2 by (rewrite !expnS; lia). - have small1: 1 < 2 ^ n.+2 by lia. - replace ((2^n.+2).-2.+2) with (2^n.+2)%N in H by lia. - rewrite (modn_small small1) modn_small in H; lia. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_3_quo_isog (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in - morphism.isog <[u3]> (<[u3]>/<[um1]>). -Proof. - intros. - apply quotient_isog_unit_Zp. - apply ord2_setI_G1. - - apply unitrN1_card_2. - - apply m1_not_in_unit_3_pow. - Qed. - -Lemma unit_pow_2_Zp_gens_m1_3_quo_isog_um1 (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in - morphism.isog <[um1]> (<[um1]>/<[u3]>). -Proof. - intros. - apply quotient_isog_unit_Zp. - rewrite setIC. - apply ord2_setI_G1. - - apply unitrN1_card_2. - - apply m1_not_in_unit_3_pow. - Qed. - -Lemma unit_pow_2_Zp_gens_m1_3_quo_isog_alt (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in - morphism.isog <[u3]> ([group of (units_Zp (2^n.+3)%N)]/<[um1]>). -Proof. - intros. - simpl. - rewrite -unit_pow_2_Zp_gens_m1_3_quo. - apply unit_pow_2_Zp_gens_m1_3_quo_isog. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_3_quo_isog_um1_alt (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in - let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in - morphism.isog <[um1]> ([group of (units_Zp (2^n.+3)%N)]/<[u3]>). -Proof. - intros. - simpl. - rewrite -unit_pow_2_Zp_gens_m1_3_quo_alt. - apply unit_pow_2_Zp_gens_m1_3_quo_isog_um1. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_5_quo_isog (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in - let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in - morphism.isog <[u5]> (<[u5]>/<[um1]>). -Proof. - intros. - apply quotient_isog_unit_Zp. - apply ord2_setI_G1. - - apply unitrN1_card_2. - - apply m1_not_in_unit_5_pow_gen. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_5_quo_isog_um1 (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in - let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in - morphism.isog <[um1]> (<[um1]>/<[u5]>). -Proof. - intros. - apply quotient_isog_unit_Zp. - rewrite setIC. - apply ord2_setI_G1. - - apply unitrN1_card_2. - - apply m1_not_in_unit_5_pow_gen. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_5_quo_isog_alt (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in - let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in - morphism.isog <[u5]> ([group of (units_Zp (2^n.+2)%N)]/<[um1]>). -Proof. - intros. - simpl. - rewrite -unit_pow_2_Zp_gens_m1_5_quo. - apply unit_pow_2_Zp_gens_m1_5_quo_isog. -Qed. - -Lemma unit_pow_2_Zp_gens_m1_5_quo_isog_um1_alt (n : nat) : - let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in - let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in - morphism.isog <[um1]> ([group of (units_Zp (2^n.+2)%N)]/<[u5]>). -Proof. - intros. - simpl. - rewrite -unit_pow_2_Zp_gens_m1_5_quo_alt. - apply unit_pow_2_Zp_gens_m1_5_quo_isog_um1. -Qed. - -End two_pow_units. - -From mathcomp Require Import matrix perm. -Section add_self. - - Context {G:ringType}. - Context {n:nat} (npos:0 < n). - - Definition modn_ord (m:nat) : 'I_n := Ordinal (@ltn_pmod m n npos). - - Lemma modn_ord_inj (i : 'I_n) : - injective (fun j : ordinal_finType n => modn_ord (i + j)). - Proof. - rewrite /modn_ord => a b [xx]. - have: i + a + (n-i) = i + b + (n - i) %[mod n]. - { - rewrite -[addn (addn _ (nat_of_ord a)) _ %% n]modnDml. - rewrite -[addn (addn _ (nat_of_ord b)) _ %% n]modnDml. - by rewrite xx. - } - rewrite (_:(addn (addn i a) (n-i) = addn a n)). - - rewrite (_:(addn (addn i b) (n-i) = addn b n)). - + rewrite !modnDr !modn_small; try apply ltn_ord. - apply val_inj. - + move: (ltn_ord i); lia. - - move: (ltn_ord i); lia. - Qed. - - Definition rotate_index_right_ord (idx:'I_n) (e:nat) - := modn_ord (idx + e). - - Lemma rotate_ind_right_ord_cancel (e:nat) : - cancel (fun (idx : 'I_n) => rotate_index_right_ord idx e) - (fun (idx : 'I_n) => rotate_index_right_ord idx (n - e %% n)). - Proof. - rewrite /rotate_index_right_ord /cancel /modn_ord /=. - intros. - apply ord_inj=> /=. - rewrite -(modnDm x e n) modnDml -addnA. - generalize (ltn_pmod e npos); intros. - replace (e %% n + (n - e %% n))%N with n by lia. - rewrite -modnDmr modnn addn0 modn_mod modn_small //. - Qed. - - Definition rot_perm e := perm (can_inj (rotate_ind_right_ord_cancel e)). - - Lemma rot_mul e1 e2 : - perm_mul (rot_perm e1) (rot_perm e2) = rot_perm (e1 + e2). - Proof. - rewrite /perm_mul /rot_perm -permP. - move => x. - rewrite permE /= !permE /rotate_index_right_ord /modn_ord. - apply ord_inj=> /=. - by rewrite modnDml addnA. - Qed. - - Definition rotate_row_right (v:'rV[G]_n) (e:nat) - := \row_(i < n) v 0 (rotate_index_right_ord i e). - - Definition row_sum_naive_rot_row (v:'rV[G]_n) - := \sum_(i < n) rotate_row_right v i. - - Definition row_sum_naive_rot (v:'rV[G]_n) - := row_sum_naive_rot_row v 0 (Ordinal npos). - - Lemma row_sum_naive_rot_row_correct (v:'rV[G]_n) - : row_sum_naive_rot_row v = const_mx (\sum_(j < n) v 0 j). - Proof. - apply/matrixP => rr i. - rewrite /row_sum_naive_rot_row/const_mx/rotate_row_right/rotate_index_right_ord. - rewrite !ord1 !mxE summxE /=. - rewrite [\sum_(j modn_ord (i + nat_of_ord j)))). - - apply eq_bigr => k _. - by rewrite !mxE. - - apply modn_ord_inj. - Qed. - - Lemma row_sum_naive_rot_correct (v:'rV[G]_n) - : row_sum_naive_rot v = \sum_(j < n) v 0 j. - Proof. - by rewrite/row_sum_naive_rot row_sum_naive_rot_row_correct/const_mx !mxE. - Qed. - -End add_self. - -Section add_self_pow. - Context {G:comRingType}. - - Lemma expn_2_pos (n : nat) : 0 < expn 2 n. - Proof. - lia. - Qed. - - Fixpoint row_sum_rot_pow_rec {n} (v:'rV[G]_(2^n)) (m : nat) : 'rV[G]_(2^n) := - match m with - | 0 => v - | S m' => row_sum_rot_pow_rec (v + rotate_row_right (expn_2_pos n) v (2^m')) m' - end. - - Lemma vals_modn_mul_0 i m n : - 0 < m -> - 0 < n -> - i == 0 %[mod m] -> - exists k, k < n /\ i == k * m %[mod m * n]. - Proof. - intros. - assert (exists q, i = (q * m)%N). - { - rewrite (modn_small H) in H1. - by move /dvdnP in H1. - } - destruct H2. - exists (x %% n). - split. - - by rewrite ltn_mod. - - apply /eqP. - by rewrite mulnC H2 muln_modl modn_mod. - Qed. - - Lemma vals_modn_mul_le i j m n : - 0 < m -> - 0 < n -> - j <= i -> - i == j %[mod m] -> - exists k, k < n /\ i == j + k * m %[mod m * n]. - Proof. - intros. - rewrite modn_sub in H2; trivial. - apply (vals_modn_mul_0 (n:=n)) in H2; trivial. - destruct H2 as [? [??]]. - exists x. - split; trivial. - assert (x * m < m * n). - { - rewrite mulnC. - by rewrite ltn_pmul2l. - } - rewrite (modn_small H4) in H3. - move /eqP in H3. - apply (f_equal (fun z => modn (j + z) (m * n)%N)) in H3. - apply /eqP. - rewrite -H3 modnDmr. - clear H3. - replace (j + (i - j))%N with i; trivial. - lia. - Qed. - - Lemma vals_modn_mul i j m n : - 0 < m -> - 0 < n -> - i == j %[mod m] -> - exists k, k < n /\ i == j + k * m %[mod m * n]. - Proof. - case (boolP (j <= i)). - - intros; by apply vals_modn_mul_le. - - intros. - rewrite eq_sym in H1. - apply (vals_modn_mul_le (n := n)) in H1; try lia. - destruct H1 as [? [??]]. - destruct x. - + rewrite mul0n addn0 in H2. - exists 0%N. - split; trivial. - by rewrite mul0n addn0 eq_sym. - + exists (n - x.+1)%N. - split; try lia. - move /eqP in H2. - apply (f_equal (fun z => modn (z + (n - x.+1) * m)%N (m * n)%N)) in H2. - rewrite modnDml in H2. - rewrite H2 modnDml. - replace (i + x.+1 * m + (n - x.+1) * m)%N with (i + m * n)%N. - * by rewrite -modnDm modnn addn0 modn_mod. - * clear H2. - lia. - Qed. - - Lemma vals_mod_2_Sn i j n : - i == j %[mod 2^n] -> - i == j %[mod 2^n.+1] \/ i == j + 2^n %[mod 2^n.+1]. - Proof. - intros. - case (boolP (i == j %[mod 2^n.+1])). - - by left. - - right. - apply (vals_modn_mul (expn_2_pos n) (n := 2)) in H; try lia. - destruct H as [? [??]]. - destruct x. - + rewrite mul0n addn0 in H0. - by rewrite expnS mulnC H0 in i0. - + assert (x = 0%N) by lia. - rewrite H1 mul1n in H0. - by rewrite expnS mulnC. - Qed. - - Definition is_partitioned_in_same_bins_by_m_to {n} (v:'rV[G]_(2^n)) m := - (forall i j, val i == val j %[mod 2^m] -> v 0 i = v 0 j). - - Lemma mod_mod_le k m n : - m <= n -> - (k %% 2^n) %% 2^m = k %% 2^m. - Proof. - intros. - rewrite modn_dvdm; trivial. - by apply dvdn_exp2l. - Qed. - - Lemma row_sum_rot_pow_rec_step_narrows_bins {n} (v:'rV[G]_(2^n)) (m : nat) : - S m <= n -> - is_partitioned_in_same_bins_by_m_to v (S m) -> - is_partitioned_in_same_bins_by_m_to (v + rotate_row_right (expn_2_pos n) v (2^m)) m. - Proof. - intro le_Sm_n. - rewrite /is_partitioned_in_same_bins_by_m_to. - intros. - rewrite !mxE /rotate_index_right_ord. - assert (val i == val j %[mod 2^m.+1] \/ val i == val j + 2^m %[mod 2^m.+1]). - { - by apply vals_mod_2_Sn. - } - destruct H1. - - f_equal; apply H; simpl. - + apply H1. - + rewrite mod_mod_le; trivial. - rewrite mod_mod_le; trivial. - move /eqP in H1. - apply /eqP. - apply (f_equal (fun z => (z + 2^m)%N %% (2^m.+1))) in H1. - by rewrite !modnDml in H1. - - move /eqP in H1. - assert (j = i + 2^m %[mod 2^m.+1]). - { - apply (f_equal (fun z => (z + 2^m)%N %% (2^m.+1))) in H1. - rewrite !modnDml in H1. - rewrite H1 -addnA. - replace (2^m + 2^m)%N with (2^m.+1). - - by rewrite modnDr. - - rewrite expnS; lia. - } - rewrite addrC. - f_equal; apply H; simpl. - + by rewrite (mod_mod_le (i + 2^m)%N le_Sm_n) -H2. - + by rewrite (mod_mod_le (j + 2^m)%N le_Sm_n) -H1. - Qed. - - Lemma row_sum_rot_pow_rec_step_preserves_sum {n} (v:'rV[G]_(2^n)) (m : nat) - (pf1:2^m.+1<=2^n) (pf2:2^m<=2^n): - \sum_(i < 2^S m) v 0 (widen_ord pf1 i) = - \sum_(i < 2^m) (v + rotate_row_right (expn_2_pos n) v (2^m)) 0 (widen_ord pf2 i). - Proof. - intros. - under [\sum_(i < 2 ^ m) _] eq_bigr do rewrite !mxE. - - rewrite /= big_split /=. - - have pf1': (2 ^ m + 2 ^ m <= 2 ^ n) by (rewrite (leq_trans _ pf1) // expnS; lia). - transitivity (\sum_(i < 2 ^ m + 2 ^ m) v 0 (widen_ord pf1' i)). - - have: 2 ^ m.+1 = addn (2 ^ m) (2 ^ m) - by (rewrite expnS; lia). - destruct 1. - apply eq_bigr => i _. - by f_equal; apply val_inj. - - rewrite big_split_ord /=. - f_equal; (apply eq_bigr => i _; f_equal; apply val_inj => //=). - rewrite addnC modn_small //. - have leq1: i < 2 ^ m by apply ltn_ord. - lia. - Qed. - - Definition row_sum_rot_pow {n} (v:'rV[G]_(2^n)) := row_sum_rot_pow_rec v n. - - Lemma is_partitioned_in_same_bins_by_m_to0 {n} (v:'rV[G]_(2^n)) : - is_partitioned_in_same_bins_by_m_to v n. - Proof. - move=> i j eqq. - suff ->: i = j => //. - rewrite !modn_small in eqq. - - apply val_inj. - by apply/eqP. - - apply ltn_ord. - - apply ltn_ord. - Qed. - - Lemma row_sum_rot_pow_is_really_binned {n} (v:'rV[G]_(2^n)) : - is_partitioned_in_same_bins_by_m_to (row_sum_rot_pow v) 0. - Proof. - rewrite /row_sum_rot_pow. - suff {v}: forall n', n' <= n -> forall v : 'rV_(2 ^ n), - is_partitioned_in_same_bins_by_m_to v n' -> - is_partitioned_in_same_bins_by_m_to (row_sum_rot_pow_rec v n') 0. - { apply. - - apply leqnn. - - apply is_partitioned_in_same_bins_by_m_to0. - } - induction n' => //= n'l v. - move/row_sum_rot_pow_rec_step_narrows_bins => HH. - apply IHn'. - - by apply ltnW. - - by apply HH. - Qed. - - Lemma row_sum_rot_pow_is_summed {n} (v:'rV[G]_(2^n)) : - \sum_(i < 2^n) v 0 i = (row_sum_rot_pow v) 0 (Ordinal (expn_2_pos n)). - Proof. - rewrite /row_sum_rot_pow. - suff {v} HH: forall n' (pf:2^n'<=2^n), forall v' : 'rV_(2 ^ n), - \sum_(i < 2^n') v' 0 (widen_ord pf i) = - (row_sum_rot_pow_rec v' n') 0 (Ordinal (expn_2_pos n)). - { - rewrite <- (HH n (leqnn _) v). - - apply eq_big => // i _. - f_equal. - by apply ord_inj. - } - induction n' => /= pf v'. - - rewrite (big_pred1_id _ _ _ _ (i:=0)). - + rewrite addr0. - f_equal. - by apply val_inj. - + move=> i /=. - by rewrite ord1 eqxx. - - have pf': is_true (2 ^ n' <= 2 ^ n). - { - rewrite (leq_trans _ pf) // expnS. - lia. - } - rewrite <- (IHn' pf'). - by apply row_sum_rot_pow_rec_step_preserves_sum. - Qed. - - (* claim at kth iteration v is a concatenation of 2^k equal vectors each of which has the same sum as the original v. *) - Lemma row_sum_rot_pow_correct {n} (v:'rV[G]_(2^n)) - : row_sum_rot_pow v = const_mx (\sum_(j < 2^n) v 0 j). - Proof. - apply/matrixP => rr i. - rewrite !ord1 /const_mx !mxE row_sum_rot_pow_is_summed /row_sum_rot_pow . - apply row_sum_rot_pow_is_really_binned. - by rewrite expn0 !modn1. - Qed. - -End add_self_pow. diff --git a/coq/NeuralNetworks/AxiomaticNormedRealVectorSpace.v b/coq/NeuralNetworks/AxiomaticNormedRealVectorSpace.v deleted file mode 100644 index d808c207..00000000 --- a/coq/NeuralNetworks/AxiomaticNormedRealVectorSpace.v +++ /dev/null @@ -1,45 +0,0 @@ -(*using definitions from http://www.math.ucla.edu/~tao/resource/general/121.1.00s/vector_axioms.html *) - -Require Import Reals.Rbase. -Require Import Reals.Rfunctions. - -Module AxiomaticNormedRealVectorSpace. - -Inductive rvector (d: nat) : Set := -| zero -| add (x y : rvector d) -| inverse (x: rvector d) -| smult (r:R) (x: rvector d). - -Class NormedVectorSpace (d: nat) := - { - norm: rvector d -> R; - - rv_axiom_additive : forall x y z : rvector d, - (add d x y) = (add d y x) /\ - (add d (add d x y) z) = (add d x (add d y z)) /\ - (add d (zero d) x) = (add d x (zero d)) /\ - (add d x (zero d)) = x /\ - (add d (inverse d x) x) = (add d x (inverse d x)) /\ - (add d x (inverse d x)) = zero d; - - rv_axiom_multiplicaive : forall (x : rvector d), forall (b c : R), - smult d R0 x = zero d /\ - smult d R1 x = x /\ - smult d (Rmult b c) x = smult d b (smult d c x); - - rv_axiom_distributive : forall (x y : rvector d), forall (a b : R), - smult d b (add d x y) = add d (smult d b x) (smult d b y) /\ - smult d (Rplus a b) x = add d (smult d a x) (smult d b x); - - norm_axiom_zero : forall (x:rvector d), - norm (zero d) = R0 <-> x=zero d; - - norm_axiom_abs : forall (x: rvector d), forall (a:R), - norm (smult d a x) = Rmult (Rabs a) (norm x); - - norm_axiom_add : forall (x y : rvector d), - norm (add d x y) = Rplus (norm x) (norm y); - }. - -End AxiomaticNormedRealVectorSpace. diff --git a/coq/NeuralNetworks/DefinedFunctions.v b/coq/NeuralNetworks/DefinedFunctions.v deleted file mode 100644 index 408fbb8e..00000000 --- a/coq/NeuralNetworks/DefinedFunctions.v +++ /dev/null @@ -1,15581 +0,0 @@ -Require Import Program. -Require Import String. -Require Import EquivDec. -Require Import RelationClasses. -Require Import List. -Require Import Permutation. -Require Import NPeano BinInt PeanoNat. -Require Import Lra Lia. -Require Reals. -Require Import Eqdep_dec. - -Require Import Floatish. -Require Import Utils. -Require Import derivlemmas. -Require Import Vector. - -Import ListNotations. - -Local Open Scope list_scope. -Declare Scope df_scope. - -Set Bullet Behavior "Strict Subproofs". - -Section DefinedFunctions. - - Context {floatish_impl:floatish}. - Local Open Scope float. - -(* in pytorch relu(f)' if f <=0 then 0 else f' *) -(* in pytorch abs(f)' = f'*sign(f) *) -(* max(a,b)' = if a<=b then b' else a' *) -(* min(a,b)' = if a>=b then b' else a' *) -(* x = Variable(torch.tensor(0.0), requires_grad=True) *) -(* z = torch.min(x*x, x); z.backward(); print(x.grad) = 1 *) -(* x.grad.data.zero_() between tests *) -(* relu behaves like max(x, 0), not max(0,x), i.e. relu(x)' at 0 = 0 *) - - - Section Definitions. - - Definition var := string. - - Inductive SubVar : Set := - | Name (s : string) - | Sub (v : SubVar) (i : nat). - - - Definition var_dec : forall v1 v2 : SubVar, {v1 = v2} + {v1 <> v2}. - Proof. - decide equality. - apply string_dec. - apply Nat.eq_dec. - Defined. - - Global Instance var_eqdec : EqDec SubVar eq. - Proof. - intros x y. - apply var_dec. - Defined. - - (* A subset of defined functions *) - - - Inductive definition_function_types - := DTfloat - | DTVector (n:nat) - | DTMatrix (m n:nat). - - Definition definition_function_types_interp (dft:definition_function_types) : Type - := match dft with - | DTfloat => float - | DTVector n => Vector float n - | DTMatrix m n => Matrix float m n - end. - - Inductive data_type : definition_function_types -> Type - := DataFloat : data_type DTfloat - | DataVector n (v:Vector float n) : data_type (DTVector n) - | DataMatrix m n (mat:Matrix float m n) : data_type (DTMatrix m n). - - Definition var_type := (SubVar * definition_function_types)%type. - - Definition definition_function_types_dec : forall v1 v2 : definition_function_types, {v1 = v2} + {v1 <> v2}. - Proof. - decide equality; apply Nat.eq_dec. - Defined. - - Definition vart_dec : forall v1 v2 : var_type, {v1 = v2} + {v1 <> v2}. - Proof. - decide equality. - - apply definition_function_types_dec. - - apply var_dec. - Defined. - - Global Instance vart_eqdec : EqDec var_type eq. - Proof. - intros ??. - apply vart_dec. - Defined. - - Lemma var_type_UIP_refl {x:var_type} (e:x=x) : e = eq_refl x. - Proof. - apply (UIP_dec vart_dec). - Qed. - - Lemma definition_function_types_UIP_refl {x:definition_function_types} (e:x=x) : e = eq_refl x. - Proof. - apply (UIP_dec definition_function_types_dec). - Qed. - - Definition env_entry_type := {v:var_type & definition_function_types_interp (snd v)}. - Definition df_env := list env_entry_type. - - Definition mk_env_entry v e : env_entry_type - := let P := fun xv => definition_function_types_interp (snd xv) in - existT P v e. - - Definition UnitAnn: definition_function_types->Type := fun _ => unit. - Definition EvalAnn: definition_function_types->Type := definition_function_types_interp. - - Inductive DefinedFunction {Ann:definition_function_types->Type} : definition_function_types -> Type := - | Number (ann:Ann DTfloat) (x : float) : DefinedFunction DTfloat - | Constant {t:definition_function_types} (ann:Ann t) (x : definition_function_types_interp t) : DefinedFunction t - | DVector {n} (ann:Ann (DTVector n)) (x : Vector (DefinedFunction DTfloat) n) : DefinedFunction (DTVector n) - | DMatrix {n m} (ann:Ann (DTMatrix n m)) (x : Matrix (DefinedFunction DTfloat) n m) : DefinedFunction (DTMatrix n m) - | Var (v : var_type) (ann: Ann (snd v)) : DefinedFunction (snd v) - | Plus (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Minus (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Times (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Divide (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Square (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Exp (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Log (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Abs (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Sign (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | PSign (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Max (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat - | VectorDot {n} (ann:Ann DTfloat) (l r: DefinedFunction (DTVector n)) : DefinedFunction DTfloat - | VectorSum {n} (ann:Ann DTfloat) (v: DefinedFunction (DTVector n)) : DefinedFunction DTfloat - | MatrixSum {m n} (ann:Ann DTfloat) (v: DefinedFunction (DTMatrix m n)) : DefinedFunction DTfloat - | VectorElem {n} (ann:Ann DTfloat) (l:DefinedFunction (DTVector n)) (i:{x:nat|x Prop) - (f : forall (ann : UnitAnn DTfloat) (x : float), - P DTfloat (Number ann x)) - (f0 : forall (t : definition_function_types) - (ann : UnitAnn t) (x : definition_function_types_interp t), P t (Constant ann x)) - (f1 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (x : Vector (DefinedFunction UnitAnn DTfloat) n), - (forall s : {n' : nat | (n' < n)%nat}, P DTfloat (x s)) -> - P (DTVector n) (DVector ann x)) - (f2 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (x : Matrix (DefinedFunction UnitAnn DTfloat) n m), - (forall (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}), - P DTfloat (x s s0)) -> P (DTMatrix n m) (DMatrix ann x)) - (f3 : forall (v : var_type) (ann : UnitAnn (snd v)), - P (snd v) (Var v ann)) - (f4 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Plus ann l r)) - (f5 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Minus ann l r)) - (f6 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Times ann l r)) - (f7 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Divide ann l r)) - (f8 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Square ann e)) - (f9 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Exp ann e)) - (f10 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Log ann e)) - (f11 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Abs ann e)) - (f12 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Sign ann e)) - (f13 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (PSign ann e)) - (f14 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Max ann l r)) - (f15 : forall (n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P DTfloat (VectorDot ann l r)) - (f16 : forall (n : nat) (ann : UnitAnn DTfloat) - (v : DefinedFunction UnitAnn (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v)) - (f17 : forall (m n : nat) (ann : UnitAnn DTfloat) - (v : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v)) - (f18 : forall (n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i)) - (f19 : forall (m n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}), - P DTfloat (MatrixElem ann l i j)) - (f20 : forall (m n : nat) (ann : UnitAnn (DTVector m)) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r)) - (f21 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n)) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction UnitAnn (DTVector m), - P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r)) - (f22 : forall (m p n : nat) (ann : UnitAnn (DTMatrix m n)) - (l : DefinedFunction UnitAnn (DTMatrix m p)), - P (DTMatrix m p) l -> - forall r : DefinedFunction UnitAnn (DTMatrix p n), - P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r)) - (f23 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r)) - (f24 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r)) - (f25 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (l : DefinedFunction UnitAnn (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r)) - (f26 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (l : DefinedFunction UnitAnn (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r)) - (f27 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (x : DefinedFunction UnitAnn DTfloat), - P DTfloat x -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l)) - (f28 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (x : DefinedFunction UnitAnn DTfloat), - P DTfloat x -> - forall l : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l)) - (f29 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l)) - (f30 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTMatrix m n), - P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l)) - (f31 : forall (n : nat) (ann : UnitAnn DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> forall r : Vector float n, P DTfloat (VLossfun ann v1 v2 s l r)) - (f32 : forall (m n : nat) (ann : UnitAnn DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTMatrix m n), - P (DTMatrix m n) l -> - forall r : Matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r)) - := -fix -F (d : definition_function_types) - (d0 : DefinedFunction UnitAnn d) {struct d0} : P d d0 := - match d0 as d2 in (DefinedFunction _ d1) return (P d1 d2) with - | Number ann x => f ann x - | @Constant _ t ann x => f0 t ann x - | @DVector _ n ann x => f1 n ann x (fun s : {n' : nat | (n' < n)%nat} => F DTfloat (x s)) - | @DMatrix _ n m ann x => - f2 n m ann x - (fun (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}) => - F DTfloat (x s s0)) - | Var v ann => f3 v ann - | Plus ann l r => f4 ann l (F DTfloat l) r (F DTfloat r) - | Minus ann l r => f5 ann l (F DTfloat l) r (F DTfloat r) - | Times ann l r => f6 ann l (F DTfloat l) r (F DTfloat r) - | Divide ann l r => f7 ann l (F DTfloat l) r (F DTfloat r) - | Square ann e => f8 ann e (F DTfloat e) - | Exp ann e => f9 ann e (F DTfloat e) - | Log ann e => f10 ann e (F DTfloat e) - | Abs ann e => f11 ann e (F DTfloat e) - | Sign ann e => f12 ann e (F DTfloat e) - | PSign ann e => f13 ann e (F DTfloat e) - | Max ann l r => f14 ann l (F DTfloat l) r (F DTfloat r) - | @VectorDot _ n ann l r => f15 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @VectorSum _ n ann v => f16 n ann v (F (DTVector n) v) - | @MatrixSum _ m n ann v => f17 m n ann v (F (DTMatrix m n) v) - | @VectorElem _ n ann l i => f18 n ann l (F (DTVector n) l) i - | @MatrixElem _ m n ann l i j => f19 m n ann l (F (DTMatrix m n) l) i j - | @MatrixVectorMult _ m n ann l r => - f20 m n ann l (F (DTMatrix m n) l) r (F (DTVector n) r) - | @MatrixVectorAdd _ m n ann l r => - f21 m n ann l (F (DTMatrix m n) l) r (F (DTVector m) r) - | @MatrixMult _ m p n ann l r => - f22 m p n ann l (F (DTMatrix m p) l) r (F (DTMatrix p n) r) - | @VectorPlus _ n ann l r => f23 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @VectorMinus _ n ann l r => f24 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @MatrixPlus _ n m ann l r => f25 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r) - | @MatrixMinus _ n m ann l r => - f26 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r) - | @VectorScalMult _ n ann x l => f27 n ann x (F DTfloat x) l (F (DTVector n) l) - | @MatrixScalMult _ n m ann x l => f28 n m ann x (F DTfloat x) l (F (DTMatrix n m) l) - | @VectorApply _ n ann v s l => f29 n ann v s (F DTfloat s) l (F (DTVector n) l) - | @MatrixApply _ m n ann v s l => - f30 m n ann v s (F DTfloat s) l (F (DTMatrix m n) l) - | @VLossfun _ n ann v1 v2 s l r => - f31 n ann v1 v2 s (F DTfloat s) l (F (DTVector n) l) r - | @MLossfun _ m n ann v1 v2 s l r => - f32 m n ann v1 v2 s (F DTfloat s) l (F (DTMatrix m n) l) r - end. - -Definition DefinedFunction_ind_simpl {Ann} - (P : forall (d : definition_function_types), DefinedFunction Ann d -> Prop) - (f : forall (ann : Ann DTfloat) (x : float), - P DTfloat (Number ann x)) - (f0 : forall (t : definition_function_types) - (ann : Ann t) (x : definition_function_types_interp t), P t (Constant ann x)) - (f1 : forall (n : nat) (ann : Ann (DTVector n)) - (x : Vector (DefinedFunction Ann DTfloat) n), - (forall s : {n' : nat | (n' < n)%nat}, P DTfloat (x s)) -> - P (DTVector n) (DVector ann x)) - (f2 : forall (n m : nat) (ann : Ann (DTMatrix n m)) - (x : Matrix (DefinedFunction Ann DTfloat) n m), - (forall (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}), - P DTfloat (x s s0)) -> P (DTMatrix n m) (DMatrix ann x)) - (f3 : forall (v : var_type) (ann : Ann (snd v)), - P (snd v) (Var v ann)) - (f4 : forall (ann : Ann DTfloat) - (l : DefinedFunction Ann DTfloat), - P DTfloat l -> - forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Plus ann l r)) - (f5 : forall (ann : Ann DTfloat) - (l : DefinedFunction Ann DTfloat), - P DTfloat l -> - forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Minus ann l r)) - (f6 : forall (ann : Ann DTfloat) - (l : DefinedFunction Ann DTfloat), - P DTfloat l -> - forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Times ann l r)) - (f7 : forall (ann : Ann DTfloat) - (l : DefinedFunction Ann DTfloat), - P DTfloat l -> - forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Divide ann l r)) - (f8 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Square ann e)) - (f9 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Exp ann e)) - (f10 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Log ann e)) - (f11 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Abs ann e)) - (f12 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Sign ann e)) - (f13 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (PSign ann e)) - (f14 : forall (ann : Ann DTfloat) - (l : DefinedFunction Ann DTfloat), - P DTfloat l -> - forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Max ann l r)) - (f15 : forall (n : nat) (ann : Ann DTfloat) - (l : DefinedFunction Ann (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction Ann (DTVector n), - P (DTVector n) r -> P DTfloat (VectorDot ann l r)) - (f16 : forall (n : nat) (ann : Ann DTfloat) - (v : DefinedFunction Ann (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v)) - (f17 : forall (m n : nat) (ann : Ann DTfloat) - (v : DefinedFunction Ann (DTMatrix m n)), - P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v)) - (f18 : forall (n : nat) (ann : Ann DTfloat) - (l : DefinedFunction Ann (DTVector n)), - P (DTVector n) l -> - forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i)) - (f19 : forall (m n : nat) (ann : Ann DTfloat) - (l : DefinedFunction Ann (DTMatrix m n)), - P (DTMatrix m n) l -> - forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}), - P DTfloat (MatrixElem ann l i j)) - (f20 : forall (m n : nat) (ann : Ann (DTVector m)) - (l : DefinedFunction Ann (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction Ann (DTVector n), - P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r)) - (f21 : forall (m n : nat) (ann : Ann (DTMatrix m n)) - (l : DefinedFunction Ann (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction Ann (DTVector m), - P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r)) - (f22 : forall (m p n : nat) (ann : Ann (DTMatrix m n)) - (l : DefinedFunction Ann (DTMatrix m p)), - P (DTMatrix m p) l -> - forall r : DefinedFunction Ann (DTMatrix p n), - P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r)) - (f23 : forall (n : nat) (ann : Ann (DTVector n)) - (l : DefinedFunction Ann (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction Ann (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r)) - (f24 : forall (n : nat) (ann : Ann (DTVector n)) - (l : DefinedFunction Ann (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction Ann (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r)) - (f25 : forall (n m : nat) (ann : Ann (DTMatrix n m)) - (l : DefinedFunction Ann (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction Ann (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r)) - (f26 : forall (n m : nat) (ann : Ann (DTMatrix n m)) - (l : DefinedFunction Ann (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction Ann (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r)) - (f27 : forall (n : nat) (ann : Ann (DTVector n)) - (x : DefinedFunction Ann DTfloat), - P DTfloat x -> - forall l : DefinedFunction Ann (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l)) - (f28 : forall (n m : nat) (ann : Ann (DTMatrix n m)) - (x : DefinedFunction Ann DTfloat), - P DTfloat x -> - forall l : DefinedFunction Ann (DTMatrix n m), - P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l)) - (f29 : forall (n : nat) (ann : Ann (DTVector n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - forall l : DefinedFunction Ann (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l)) - (f30 : forall (m n : nat) (ann : Ann (DTMatrix m n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - forall l : DefinedFunction Ann (DTMatrix m n), - P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l)) - (f31 : forall (n : nat) (ann : Ann DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - forall l : DefinedFunction Ann (DTVector n), - P (DTVector n) l -> forall r : Vector float n, P DTfloat (VLossfun ann v1 v2 s l r)) - (f32 : forall (m n : nat) (ann : Ann DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - forall l : DefinedFunction Ann (DTMatrix m n), - P (DTMatrix m n) l -> - forall r : Matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r)) - := -fix -F (d : definition_function_types) - (d0 : DefinedFunction Ann d) {struct d0} : P d d0 := - match d0 as d2 in (DefinedFunction _ d1) return (P d1 d2) with - | Number ann x => f ann x - | @Constant _ t ann x => f0 t ann x - | @DVector _ n ann x => f1 n ann x (fun s : {n' : nat | (n' < n)%nat} => F DTfloat (x s)) - | @DMatrix _ n m ann x => - f2 n m ann x - (fun (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}) => - F DTfloat (x s s0)) - | Var v ann => f3 v ann - | Plus ann l r => f4 ann l (F DTfloat l) r (F DTfloat r) - | Minus ann l r => f5 ann l (F DTfloat l) r (F DTfloat r) - | Times ann l r => f6 ann l (F DTfloat l) r (F DTfloat r) - | Divide ann l r => f7 ann l (F DTfloat l) r (F DTfloat r) - | Square ann e => f8 ann e (F DTfloat e) - | Exp ann e => f9 ann e (F DTfloat e) - | Log ann e => f10 ann e (F DTfloat e) - | Abs ann e => f11 ann e (F DTfloat e) - | Sign ann e => f12 ann e (F DTfloat e) - | PSign ann e => f13 ann e (F DTfloat e) - | Max ann l r => f14 ann l (F DTfloat l) r (F DTfloat r) - | @VectorDot _ n ann l r => f15 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @VectorSum _ n ann v => f16 n ann v (F (DTVector n) v) - | @MatrixSum _ m n ann v => f17 m n ann v (F (DTMatrix m n) v) - | @VectorElem _ n ann l i => f18 n ann l (F (DTVector n) l) i - | @MatrixElem _ m n ann l i j => f19 m n ann l (F (DTMatrix m n) l) i j - | @MatrixVectorMult _ m n ann l r => - f20 m n ann l (F (DTMatrix m n) l) r (F (DTVector n) r) - | @MatrixVectorAdd _ m n ann l r => - f21 m n ann l (F (DTMatrix m n) l) r (F (DTVector m) r) - | @MatrixMult _ m p n ann l r => - f22 m p n ann l (F (DTMatrix m p) l) r (F (DTMatrix p n) r) - | @VectorPlus _ n ann l r => f23 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @VectorMinus _ n ann l r => f24 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @MatrixPlus _ n m ann l r => f25 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r) - | @MatrixMinus _ n m ann l r => - f26 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r) - | @VectorScalMult _ n ann x l => f27 n ann x (F DTfloat x) l (F (DTVector n) l) - | @MatrixScalMult _ n m ann x l => f28 n m ann x (F DTfloat x) l (F (DTMatrix n m) l) - | @VectorApply _ n ann v s l => f29 n ann v s l (F (DTVector n) l) - | @MatrixApply _ m n ann v s l => - f30 m n ann v s l (F (DTMatrix m n) l) - | @VLossfun _ n ann v1 v2 s l r => - f31 n ann v1 v2 s l (F (DTVector n) l) r - | @MLossfun _ m n ann v1 v2 s l r => - f32 m n ann v1 v2 s l (F (DTMatrix m n) l) r - end. - - Definition get_annotation {Ann T} (df:DefinedFunction Ann T) : Ann T - := match df with - | Number ann _ => ann - | Constant _ ann _ => ann - | DVector _ ann _ => ann - | DMatrix _ _ ann _ => ann - | Var _ ann => ann - | Plus ann _ _ => ann - | Minus ann _ _ => ann - | Times ann _ _ => ann - | Divide ann _ _ => ann - | Square ann _ => ann - | Exp ann _ => ann - | Log ann _ => ann - | Abs ann _ => ann - | Sign ann _ => ann - | PSign ann _ => ann - | Max ann _ _ => ann - | VectorDot _ ann _ _ => ann - | VectorSum _ ann _ => ann - | MatrixSum _ _ ann _ => ann - | VectorElem _ ann _ _ => ann - | MatrixElem _ _ ann _ _ _ => ann - | MatrixVectorMult _ _ ann _ _ => ann - | MatrixVectorAdd _ _ ann _ _ => ann - | MatrixMult _ _ _ ann _ _ => ann - | VectorPlus _ ann _ _ => ann - | VectorMinus _ ann _ _ => ann - | MatrixPlus _ _ ann _ _ => ann - | MatrixMinus _ _ ann _ _ => ann - | VectorScalMult _ ann _ _ => ann - | MatrixScalMult _ _ ann _ _ => ann - | VectorApply _ ann _ _ _ => ann - | MatrixApply _ _ ann _ _ _ => ann - | VLossfun _ ann _ _ _ _ _ => ann - | MLossfun _ _ ann _ _ _ _ _ => ann - end. - - Definition dft_eq_dec : - forall (t1 t2 : definition_function_types), {t1 = t2} + {t1 <> t2}. - Proof. - decide equality. - decide equality. - apply Nat.eq_dec. - apply Nat.eq_dec. - Defined. - - Global Instance dft_eqdec : EqDec definition_function_types eq. - Proof. - intros ??. - apply dft_eq_dec. - Defined. - - End Definitions. - - Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "DVector"%string - | Case_aux c "DMatrix"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string - | Case_aux c "VectorDot"%string - | Case_aux c "VectorSum"%string - | Case_aux c "MatrixSum"%string - | Case_aux c "VectorElem"%string - | Case_aux c "MatrixElem"%string - | Case_aux c "MatrixVectorMult"%string - | Case_aux c "MatrixVectorAdd"%string - | Case_aux c "MatrixMult"%string - | Case_aux c "VectorPlus"%string - | Case_aux c "VectorMinus"%string - | Case_aux c "MatrixPlus"%string - | Case_aux c "MatrixMinus"%string - | Case_aux c "VectorScalMult"%string - | Case_aux c "MatrixScalMult"%string - | Case_aux c "VectorApply"%string - | Case_aux c "MatrixApply"%string - | Case_aux c "VLossfun"%string - | Case_aux c "MLossfun"%string]. - - - Ltac refl_simpler := - repeat - match goal with - | [H: @eq var_type _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H) - | [H: @equiv var_type _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H) - | [H: @eq definition_function_types _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H) - | [H: @equiv definition_function_types _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H) - end. - - - Definition df_plus (df1 df2 : DefinedFunction UnitAnn DTfloat) : DefinedFunction UnitAnn DTfloat := - Plus tt df1 df2. - - Definition df_times (df1 df2 : DefinedFunction UnitAnn DTfloat) : DefinedFunction UnitAnn DTfloat := - Times tt df1 df2. - - Definition defined_sum {m} (v:Vector (DefinedFunction UnitAnn DTfloat) m) : DefinedFunction UnitAnn DTfloat - := vector_fold_right1 (fun a b => Plus tt a b) (Number tt 0) id v. - - Definition vsum {m:nat} (v:Vector float m) : float - := vector_fold_right1 Fplus 0 id v. - - Definition msum {m n:nat} (v:Matrix float m n) : float := - vsum (vmap vsum v). - - Definition matrix_vector_mult {n m} (l : Matrix float n m)(r : Vector float m) : Vector float n := - fun i => vsum (fun j => (l i j) * (r j)). - - Definition matrix_vector_add {n m} (l : Matrix float n m) (r : Vector float n) : Matrix float n m := fun i j => (l i j) + (r i). - - Definition matrix_mult {n m p} (l : Matrix float n m)(r : Matrix float m p) : Matrix float n p := - fun i k => vsum (fun j => (l i j) * (r j k)). - - - Section deriv. - - - Section subst. - - Definition substvar {Ann} (v vv:var_type) (e':DefinedFunction Ann (snd v)) (e:DefinedFunction Ann (snd vv)) : (DefinedFunction Ann (snd vv)) := - match snd v == snd vv with - | left pf => eq_rect _ (fun t => DefinedFunction Ann t) e' _ pf - | right pf => e - end. - - Fixpoint df_subst {T Ann} (df: DefinedFunction Ann T) (v:var_type) (e':DefinedFunction UnitAnn (snd v)) := - match df with - | Number _ x => Number tt x - | Constant t _ x => Constant tt x - | DVector n _ df => DVector tt (fun x => df_subst (df x) v e') - | DMatrix n m _ df => DMatrix tt (fun i j => df_subst (df i j) v e') - | Var vvar _ => substvar v vvar e' (Var vvar tt) - | Plus _ l r => Plus tt (df_subst l v e') (df_subst r v e') - | Times _ l r => Times tt (df_subst l v e') (df_subst r v e') - | Minus _ l r => Minus tt (df_subst l v e') (df_subst r v e') - | Divide _ l r => Divide tt (df_subst l v e') (df_subst r v e') - | Square _ e => Square tt (df_subst e v e') - | Exp _ e => Exp tt (df_subst e v e') - | Log _ e => Log tt (df_subst e v e') - | Abs _ e => Abs tt (df_subst e v e') - | Sign _ e => Sign tt (df_subst e v e') - | PSign _ e => PSign tt (df_subst e v e') - | Max _ l r => Max tt (df_subst l v e') (df_subst r v e') - | VectorElem n _ l i => VectorElem tt (df_subst l v e') i - | MatrixElem m n _ l i j => MatrixElem tt (df_subst l v e') i j - | VectorDot n _ l r => - VectorDot tt (df_subst l v e') (df_subst r v e') - | VectorSum n _ e => - VectorSum tt (df_subst e v e') - | MatrixSum n m _ e => - MatrixSum tt (df_subst e v e') - | VectorScalMult n _ x r => - VectorScalMult tt (df_subst x v e') (df_subst r v e') - | MatrixScalMult n m _ x r => - MatrixScalMult tt (df_subst x v e') (df_subst r v e') - | MatrixVectorMult n m _ l r => - MatrixVectorMult tt (df_subst l v e') (df_subst r v e') - | MatrixVectorAdd n m _ l r => - MatrixVectorAdd tt (df_subst l v e') (df_subst r v e') - | MatrixMult n m p _ l r => - MatrixMult tt (df_subst l v e') (df_subst r v e') - | VectorPlus n _ l r => - VectorPlus tt (df_subst l v e') (df_subst r v e') - | VectorMinus n _ l r => - VectorMinus tt (df_subst l v e') (df_subst r v e') - | MatrixPlus n m _ l r => - MatrixPlus tt (df_subst l v e') (df_subst r v e') - | MatrixMinus n m _ l r => - MatrixMinus tt (df_subst l v e') (df_subst r v e') - | VectorApply n _ x s l => - VectorApply tt x s (df_subst l v e') - | MatrixApply n m _ x s l => - MatrixApply tt x s (df_subst l v e') - | VLossfun n _ v1 v2 s l r => - VLossfun tt v1 v2 s (df_subst l v e') r - | MLossfun n m _ v1 v2 s l r => - MLossfun tt v1 v2 s (df_subst l v e') r - end. - - Definition df_substp {T Ann} := - fun e (ve':{v:var_type & DefinedFunction UnitAnn (snd v)}) => - @df_subst T Ann e (projT1 ve') (projT2 ve'). - - Definition df_subst_list {T} (e:DefinedFunction UnitAnn T) - (l:list {v:var_type & DefinedFunction UnitAnn (snd v)}) : DefinedFunction UnitAnn T - := fold_left (@df_substp T UnitAnn) l e. - - End subst. - - - (* restrict to scalar v? *) - - Fixpoint df_deriv {T} (df:DefinedFunction UnitAnn T) (v:var_type) {struct df} : DefinedFunction UnitAnn T - := (match df with - | Number _ _ => Number tt 0 - | Constant t _ x => Constant tt - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ df => DVector tt (fun x => df_deriv (df x) v) - | DMatrix n m _ df => DMatrix tt (fun i j => df_deriv (df i j) v) - | Var x _ => Constant tt - match snd x as y return definition_function_types_interp y with - | DTfloat => if x == v then 1 else 0 - | DTVector n => ConstVector n (if x == v then 1 else 0) - | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0) - end - | Plus _ l r => Plus tt (df_deriv l v) (df_deriv r v) - | Minus _ l r => Minus tt (df_deriv l v) (df_deriv r v) - | Times _ l r => Plus tt (Times tt l (df_deriv r v)) - (Times tt (df_deriv l v) r) - | Divide _ l r => Minus tt - (Divide tt (df_deriv l v) r) - (Divide tt (Times tt l (df_deriv r v)) - (Times tt r r)) - | Square _ e => Times tt - (Times tt (Number tt 2) e) (df_deriv e v) - | Exp _ e => Times tt (df_deriv e v) (Exp tt e) - | Log _ e => Divide tt (df_deriv e v) e - | Abs _ e => Times tt (df_deriv e v) (Sign tt e) - | Sign _ e => Number tt 0 - | PSign _ e => Number tt 0 - | Max _ l r => Divide tt - (Plus tt - (Times tt (Minus tt - (df_deriv r v) - (df_deriv l v)) - (PSign tt (Minus tt r l))) - (Plus tt (df_deriv r v) (df_deriv l v))) - (Number tt 2) - | VectorElem n _ l i => VectorElem tt (df_deriv l v) i - | MatrixElem m n _ l i j => MatrixElem tt (df_deriv l v) i j - | VectorDot n _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - Plus tt (VectorDot tt ll r) (VectorDot tt l rr) - | VectorSum n _ l => - let ll := df_deriv l v in - VectorSum tt ll - | MatrixSum m n _ l => - let ll := df_deriv l v in - MatrixSum tt ll - | VectorScalMult n _ x r => - let xx := df_deriv x v in - let rr := df_deriv r v in - VectorPlus tt - (VectorScalMult tt xx r) - (VectorScalMult tt x rr) - | MatrixScalMult n m _ x r => - let xx := df_deriv x v in - let rr := df_deriv r v in - MatrixPlus tt - (MatrixScalMult tt xx r) (MatrixScalMult tt x rr) - | MatrixVectorMult n m _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - VectorPlus tt (MatrixVectorMult tt ll r) - (MatrixVectorMult tt l rr) - | MatrixVectorAdd n m _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - MatrixVectorAdd tt ll rr - | MatrixMult n m p _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - MatrixPlus tt (MatrixMult tt ll r) (MatrixMult tt l rr) - | VectorPlus n _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - VectorPlus tt ll rr - | VectorMinus n _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - VectorMinus tt ll rr - | MatrixPlus n m _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - MatrixPlus tt ll rr - | MatrixMinus n m _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - MatrixMinus tt ll rr - | VectorApply n _ x s r => - let rr := df_deriv r v in - let ss := df_deriv s (x, DTfloat) in - DVector tt (fun i => Times tt (VectorElem tt rr i) (df_subst ss (x, DTfloat) (VectorElem tt r i))) - | MatrixApply n m _ x s r => - let rr := df_deriv r v in - let ss := df_deriv s (x, DTfloat) in - DMatrix tt (fun i j => Times tt (MatrixElem tt rr i j) (df_subst ss (x, DTfloat) (MatrixElem tt r i j))) - | VLossfun n _ v1 v2 s l r => - let ll := df_deriv l v in - let ss := df_deriv s (v1, DTfloat) in - VectorDot tt ll - (DVector tt (fun i => - df_subst (df_subst ss (v1, DTfloat) (VectorElem tt l i)) - (v2, DTfloat) (Number tt (r i)))) - | MLossfun n m _ v1 v2 s l r => - let ll := df_deriv l v in - let ss := df_deriv s (v1, DTfloat) in - MatrixSum tt - (DMatrix tt - (fun i j => - (Divide tt - (Times tt (MatrixElem tt ll i j) - (df_subst (df_subst ss (v1, DTfloat) (MatrixElem tt l i j)) - (v2, DTfloat) (Number tt (r i j)))) - (Number tt (FfromZ (Z.of_nat m)))))) - end). - - Definition df_gradient {T} (df:DefinedFunction UnitAnn T) (lv:list var_type) : list (DefinedFunction UnitAnn T) - := map (df_deriv df) lv. - - End deriv. - - Section eval. - - Program - Fixpoint vartlookup (l:df_env) (a:var_type) : - option (definition_function_types_interp (snd a)) - := match l with - | nil => None - | fv::os => if a == (projT1 fv) then - Some (eq_rect _ definition_function_types_interp (projT2 fv) _ _) - else vartlookup os a - end. - - Fixpoint vart_update (l:df_env) (a:var_type) (n:definition_function_types_interp (snd a)) : df_env - := match l with - | nil => (mk_env_entry a n)::nil - | fv::os => if a == (projT1 fv) then - (mk_env_entry a n)::os else fv::(vart_update os a n) - end. - - Fixpoint df_eval {T Ann} (σ:df_env) (df:DefinedFunction Ann T) : option (definition_function_types_interp T) - := match df with - | Number _ r => Some r - | Constant t _ x => Some x - | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval σ (dfs i)) - | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval σ (df i j)) - | Var x _ => vartlookup σ x - | Plus _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (l' + r') - | _, _ => None - end - | Minus _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (l' - r') - | _, _ => None - end - | Times _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (l' * r') - | _, _ => None - end - | Divide _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (l' / r') - | _, _ => None - end - | Square _ e => - match df_eval σ e with - | Some v => Some (v * v) - | _ => None - end - | Exp _ e => - match df_eval σ e with - | Some v => Some (Fexp v) - | _ => None - end - | Log _ e => - match df_eval σ e with - | Some v => Some (Fln v) - | _ => None - end - | Abs _ e => - match df_eval σ e with - | Some v => Some (Fabs v) - | _ => None - end - | Sign _ e => - match df_eval σ e with - | Some v => Some (sign v) - | _ => None - end - | PSign _ e => - match df_eval σ e with - | Some v => Some (pos_sign v) - | _ => None - end - | Max _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (Fmax l' r') - | _, _ => None - end - | VectorElem n _ l i => - match (df_eval σ l) with - | Some l' => Some (l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval σ l) with - | Some l' => Some (l' i j) - | _ => None - end - | VectorDot n _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (vsum (fun i => (l' i) * (r' i))) - | _, _ => None - end - | VectorSum n _ l => - match df_eval σ l with - | Some l' => Some (vsum l') - | _ => None - end - | MatrixSum n m _ l => - match df_eval σ l with - | Some l' => Some (msum l') - | _ => None - end - | VectorScalMult n _ x r => - match df_eval σ x, df_eval σ r with - | Some x', Some r' => Some (fun j => x' * (r' j)) - | _, _ => None - end - | MatrixScalMult n m _ x r => - match df_eval σ x, df_eval σ r with - | Some x', Some r' => Some (fun i j => x' * (r' i j)) - | _, _ => None - end - | MatrixVectorMult n m _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (matrix_vector_mult l' r') - | _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (matrix_vector_add l' r') - | _, _ => None - end - | MatrixMult n m p _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (matrix_mult l' r') - | _, _ => None - end - | VectorPlus n _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (fun i => (l' i) + (r' i)) - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (fun i => (l' i) - (r' i)) - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j)) - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j)) - | _, _ => None - end - | VectorApply n _ x s r => - match df_eval σ r with - | Some r' => vectoro_to_ovector - (fun i => - let xv := (x, DTfloat):var_type in - df_eval (cons (mk_env_entry xv (r' i)) nil) s) - | _ => None - end - | MatrixApply n m _ x s r => - match df_eval σ r with - | Some r' => matrixo_to_omatrix - (fun i j => - let xv := (x, DTfloat):var_type in - df_eval (cons (mk_env_entry xv (r' i j)) nil) s) - | _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_eval σ l with - | Some l' => - match (vectoro_to_ovector - (fun i => - let xv1 := (v1,DTfloat):var_type in - let xv2 := (v2,DTfloat):var_type in - df_eval (cons (mk_env_entry xv1 (l' i)) - (cons (mk_env_entry xv2 (r i)) nil)) s)) with - | Some vv => Some (vsum vv) - | _ => None - end - | _ => None - end - | MLossfun n m _ v1 v2 s l r => - match df_eval σ l with - | Some l' => - match (matrixo_to_omatrix - (fun i j => - let xv1 := (v1,DTfloat):var_type in - let xv2 := (v2,DTfloat):var_type in - df_eval (cons (mk_env_entry xv1 (l' i j)) - (cons (mk_env_entry xv2 (r i j)) nil)) s)) with - | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m))) - | _ => None - end - | _ => None - end - - end. - - Fixpoint df_eval_tree {T Ann} (σ:df_env) (df:DefinedFunction Ann T) : option (DefinedFunction EvalAnn T) - := match df with - | Number _ r => Some (Number r r) - | Constant t _ x => Some (Constant x x) - | DVector n _ dfs => - match vectoro_to_ovector (fun i => df_eval_tree σ (dfs i)) with - | Some val => Some (DVector (vmap get_annotation val) val) - | _ => None - end - | DMatrix n m _ df => - match matrixo_to_omatrix (fun i j => df_eval_tree σ (df i j)) with - | Some val => Some (DMatrix - (vmap (fun x => vmap get_annotation x) val) val) - | _ => None - end - | Var x _ => match vartlookup σ x with - | Some val => Some (Var x val) - | _ => None - end - | Plus _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => Some (Plus ((get_annotation l') + (get_annotation r')) - l' r') - | _, _ => None - end - | Minus _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => Some (Minus ((get_annotation l') - (get_annotation r')) - l' r') - | _, _ => None - end - | Times _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => Some (Times ((get_annotation l') * (get_annotation r')) - l' r') - | _, _ => None - end - | Divide _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => Some (Divide ((get_annotation l') / (get_annotation r')) - l' r') - | _, _ => None - end - | Square _ e => - match df_eval_tree σ e with - | Some vv => let v := get_annotation vv in Some (Square (v * v) vv) - | _ => None - end - | Exp _ e => - match df_eval_tree σ e with - | Some vv => Some (Exp (Fexp (get_annotation vv)) vv) - | _ => None - end - | Log _ e => - match df_eval_tree σ e with - | Some vv => Some (Log (Fln (get_annotation vv)) vv) - | _ => None - end - | Abs _ e => - match df_eval_tree σ e with - | Some vv => Some (Abs (Fabs (get_annotation vv)) vv) - | _ => None - end - | Sign _ e => - match df_eval_tree σ e with - | Some vv => Some (Sign (sign (get_annotation vv)) vv) - | _ => None - end - | PSign _ e => - match df_eval_tree σ e with - | Some vv => Some (PSign (pos_sign (get_annotation vv)) vv) - | _ => None - end - | Max _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => Some (Max (Fmax (get_annotation l') (get_annotation r')) - l' r') - | _, _ => None - end - | VectorElem n _ l i => - match (df_eval_tree σ l) with - | Some l' => let vl' := get_annotation l' in - Some (VectorElem (vl' i) l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_tree σ l) with - | Some l' => let vl' := get_annotation l' in - Some (MatrixElem (vl' i j) l' i j) - | _ => None - end - | VectorDot n _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - Some (VectorDot (vsum (fun i => (vl' i) * (vr' i))) l' r') - | _, _ => None - end - | VectorSum n _ l => - match df_eval_tree σ l with - | Some l' => let vl' := get_annotation l' in - Some (VectorSum (vsum vl') l') - | _ => None - end - | MatrixSum n m _ l => - match df_eval_tree σ l with - | Some l' => let vl' := get_annotation l' in - Some (MatrixSum (msum vl') l') - | _ => None - end - | VectorScalMult n _ x r => - match df_eval_tree σ x, df_eval_tree σ r with - | Some x', Some r' => let vx' := get_annotation x' in - let vr' := get_annotation r' in - let vec : Vector float n := (fun j => vx' * (vr' j)) in - Some (VectorScalMult vec x' r') - | _, _ => None - end - | MatrixScalMult n m _ x r => - match df_eval_tree σ x, df_eval_tree σ r with - | Some x', Some r' => let vx' := get_annotation x' in - let vr' := get_annotation r' in - let mat : Matrix float n m := fun i j => vx' * (vr' i j) in - Some (MatrixScalMult mat x' r') - | _, _ => None - end - | MatrixVectorMult n m _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - Some (MatrixVectorMult (matrix_vector_mult vl' vr') l' r') - | _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - Some (MatrixVectorAdd (matrix_vector_add vl' vr') l' r') - | _, _ => None - end - | MatrixMult n m p _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - Some (MatrixMult (matrix_mult vl' vr') l' r') - | _, _ => None - end - | VectorPlus n _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - let vec : Vector float n := - fun i => (vl' i) + (vr' i) in - Some (VectorPlus vec l' r') - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - let vec : Vector float n := - fun i => (vl' i) - (vr' i) in - Some (VectorMinus vec l' r') - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - let mat : Matrix float n m := - fun i j => (vl' i j) + (vr' i j) in - Some (MatrixPlus mat l' r') - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - let mat : Matrix float n m := - fun i j => (vl' i j) - (vr' i j) in - Some (MatrixMinus mat l' r') - | _, _ => None - end - - | VectorApply n _ x s r => - match df_eval_tree σ r with - | Some r' => - let vr' := get_annotation r' in - match vectoro_to_ovector - (fun i => - let xv := (x, DTfloat):var_type in - df_eval (cons (mk_env_entry xv (vr' i)) nil) s) with - | Some val => Some (VectorApply val x s r') - | _ => None - end - | _ => None - end - | MatrixApply n m _ x s r => - match df_eval_tree σ r with - | Some r' => - let vr' := get_annotation r' in - match matrixo_to_omatrix - (fun i j => - let xv := (x, DTfloat):var_type in - df_eval (cons (mk_env_entry xv (vr' i j)) nil) s) with - | Some val => Some (MatrixApply val x s r') - | _ => None - end - | _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_eval_tree σ l with - | Some l' => - let vl' := get_annotation l' in - match (vectoro_to_ovector - (fun i => - let xv1 := (v1,DTfloat):var_type in - let xv2 := (v2,DTfloat):var_type in - df_eval (cons (mk_env_entry xv1 (vl' i)) - (cons (mk_env_entry xv2 (r i)) nil)) s)) with - | Some vv => Some (VLossfun (vsum vv) v1 v2 s l' r) - | _ => None - end - | _ => None - end - | MLossfun n m _ v1 v2 s l r => - match df_eval_tree σ l with - | Some l' => - let vl' := get_annotation l' in - match (matrixo_to_omatrix - (fun i j => - let xv1 := (v1,DTfloat):var_type in - let xv2 := (v2,DTfloat):var_type in - df_eval (cons (mk_env_entry xv1 (vl' i j)) - (cons (mk_env_entry xv2 (r i j)) nil)) s)) with - | Some vv => Some (MLossfun ((msum vv)/(FfromZ (Z.of_nat m))) v1 v2 s l' r) - | _ => None - end - | _ => None - end - end. - - Definition eval_env_entry_type := {T:definition_function_types & (DefinedFunction UnitAnn T) & definition_function_types_interp T}. - Definition df_eval_env := list eval_env_entry_type. - - Definition mk_eval_env_entry {T} df val : eval_env_entry_type - := let P := fun t => DefinedFunction UnitAnn t in - let Q := fun t => definition_function_types_interp t in - existT2 P Q T df val. - - Definition pair_update_evals {T} (df:DefinedFunction UnitAnn T) (val:definition_function_types_interp T) (dfevals : df_eval_env) : (definition_function_types_interp T * df_eval_env) := - (val, (mk_eval_env_entry df val)::dfevals). - - Fixpoint df_evals_list {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (dfevals : df_eval_env) : option (definition_function_types_interp T * df_eval_env) - := match df with - | Number _ r => Some (pair_update_evals (Number tt r) r dfevals) - | Constant t _ x => Some (pair_update_evals (Constant tt x) x dfevals) - | DVector n _ dfs => None (*vectoro_to_ovector (fun i => df_eval σ (dfs i))*) - | DMatrix n m _ df => None (*matrixo_to_omatrix (fun i j => df_eval σ (df i j))*) - | Var x _ => - match vartlookup σ x with - | Some val => Some (pair_update_evals (Var x tt) val dfevals) - | _ => None - end - | Plus _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => Some (pair_update_evals (Plus tt l r) (l'+r') dfevals'') - | _ => None - end - | _ => None - end - | Minus _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => Some (pair_update_evals (Minus tt l r) (l'-r') dfevals'') - | _ => None - end - | _ => None - end - | Times _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => Some (pair_update_evals (Times tt l r) (l'*r') dfevals'') - | _ => None - end - | _ => None - end - | Divide _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => Some (pair_update_evals (Divide tt l r) (l'/r') dfevals'') - | _ => None - end - | _ => None - end - | Square _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (Square tt e) (v * v) dfevals') - | _ => None - end - | Exp _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (Exp tt e) (Fexp v) dfevals') - | _ => None - end - | Log _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (Log tt e) (Fln v) dfevals') - | _ => None - end - | Abs _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (Abs tt e) (Fabs v) dfevals') - | _ => None - end - | Sign _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (Sign tt e) (sign v) dfevals') - | _ => None - end - | PSign _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (PSign tt e) (pos_sign v) dfevals') - | _ => None - end - | Max _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => Some (pair_update_evals (Max tt l r) (Fmax l' r') dfevals'') - | _ => None - end - | _ => None - end - | VectorElem n _ l i => - match (df_evals_list σ l dfevals) with - | Some (l', dfevals') => Some (pair_update_evals (VectorElem tt l i) (l' i) dfevals') - | _ => None - end - | MatrixElem m n _ l i j => - match (df_evals_list σ l dfevals) with - | Some (l', dfevals') => Some (pair_update_evals (MatrixElem tt l i j) (l' i j) dfevals') - | _ => None - end - | VectorDot n _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (VectorDot tt l r) - (vsum (fun i => (l' i) * (r' i))) dfevals'') - | _ => None - end - | _ => None - end - | VectorSum n _ l => - match df_evals_list σ l dfevals with - | Some (l',dfevals') => Some (pair_update_evals (VectorSum tt l) (vsum l') dfevals') - | _ => None - end - | MatrixSum n m _ l => - match df_evals_list σ l dfevals with - | Some (l',dfevals') => Some (pair_update_evals (MatrixSum tt l) (msum l') dfevals') - | _ => None - end - | VectorScalMult n _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (VectorScalMult tt l r) (fun j => l' * (r' j)) dfevals'') - | _ => None - end - | _ => None - end - | MatrixScalMult n m _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixScalMult tt l r) (fun i j => l' * (r' i j)) dfevals'') - | _ => None - end - | _ => None - end - | MatrixVectorMult n m _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixVectorMult tt l r) - (matrix_vector_mult l' r') dfevals'') - | _ => None - end - | _ => None - end - | MatrixVectorAdd n m _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixVectorAdd tt l r) - (matrix_vector_add l' r') dfevals'') - | _ => None - end - | _ => None - end - | MatrixMult n m p _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixMult tt l r) (matrix_mult l' r') dfevals'') - | _ => None - end - | _ => None - end - | VectorPlus n _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (VectorPlus tt l r) - (fun i => (l' i) + (r' i)) dfevals'') - | _ => None - end - | _ => None - end - | VectorMinus n _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (VectorMinus tt l r) - (fun i => (l' i) - (r' i)) dfevals'') - | _ => None - end - | _ => None - end - | MatrixPlus n m _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixPlus tt l r) - (fun i j => (l' i j) + (r' i j)) dfevals'') - | _ => None - end - | _ => None - end - | MatrixMinus n m _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixMinus tt l r) - (fun i j => (l' i j) - (r' i j)) dfevals'') - | _ => None - end - | _ => None - end - | VectorApply n _ x s r => - match df_evals_list σ r dfevals with -(* | Some r' => vectoro_to_ovector - (fun i => - let xv := (x, DTfloat):var_type in - df_eval (cons (mk_env_entry xv (r' i)) σ) s) *) - | _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_evals_list σ l dfevals with -(* | Some l' => - match (vectoro_to_ovector - (fun i => - let xv1 := (v1,DTfloat):var_type in - let xv2 := (v2,DTfloat):var_type in - df_eval (cons (mk_env_entry xv1 (l' i)) - (cons (mk_env_entry xv2 (r i)) σ)) s)) with - | Some vv => Some (vsum vv) - | _ => None - end *) - | _ => None - end - | _ => None - end. - -(* - Program - Fixpoint evalslookup {T} (l:df_eval_env) (df:DefinedFunction UnitAnn T) : - option (definition_function_types_interp T) - := match l with - | nil => None - | fv::os => if T == (projT1 (sigT_of_sigT2 fv)) then - if df == (projT2 (sigT_of_sigT2 fv)) then - Some (eq_rect _ definition_function_types_interp (projT3 fv) _ _) - else evalslookup os df - else evalslookup os df - end. -*) - Definition df_eval_symbolic_gradient {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (lv:list var_type) : option (list (definition_function_types_interp T)) - := listo_to_olist (map (df_eval σ) (df_gradient df lv)). - - End eval. - - Section isderiv. - - Context (σ:df_env). - Context (v:SubVar). -(* - Inductive is_deriv : DefinedFunction -> float -> Prop - := - | is_deriv_Number (x : float) : is_deriv (Number x) 0 - | is_deriv_Var_eq : is_deriv (Var v) 1 - | is_deriv_Var_neq (sv : SubVar) : sv <> v -> is_deriv (Var sv) 0 - | is_deriv_Plus l l' r r' : - is_deriv l l' -> - is_deriv r r' -> - is_deriv (Plus l r) (l' + r') - | is_deriv_Minus l l' r r' : - is_deriv l l' -> - is_deriv r r' -> - is_deriv (Minus l r) (l' - r') - | is_deriv_Times l le l' r re r' : - df_eval σ l = Some le -> - is_deriv l l' -> - df_eval σ r = Some re -> - is_deriv r r' -> - is_deriv (Times l r) ((le * r') + (l' * re)) - | is_deriv_Divide l le l' r re r' : - df_eval σ l = Some le -> - is_deriv l l' -> - df_eval σ r = Some re -> - is_deriv r r' -> - is_deriv (Times l r) - (((l' * re ) - (le * r')) - / (re * re)) - | is_deriv_Exp e ee e' : - df_eval σ e = Some ee -> - is_deriv e e' -> - is_deriv (Exp e) (e' * (Fexp ee)) - | is_deriv_Log e ee e' : - df_eval σ e = Some ee -> - is_deriv e e' -> - is_deriv (Exp e) (e' / ee) - | is_deriv_Abs e ee e' : - df_eval σ e = Some ee -> - is_deriv e e' -> is_deriv (Abs e) (e' * (sign ee)) - | is_deriv_Sign (e : DefinedFunction) : - is_deriv (Sign e) 0 - | is_deriv_PSign (e : DefinedFunction) : - is_deriv (PSign e) 0 - | is_deriv_Max_l l le l' re r : - df_eval σ l = Some le -> - df_eval σ r = Some re -> - (le > re) = true -> - is_deriv l l' -> - is_deriv (Max l r) l' - | is_deriv_Max_r l le r re r' : - df_eval σ l = Some le -> - df_eval σ r = Some re -> - (re >= le) = true -> - is_deriv r r' -> - is_deriv (Max l r) r'. - (* - | is_deriv_Max_eq l l' ee r r' : - df_eval σ l = Some ee -> - df_eval σ r = Some ee -> - is_deriv l l' -> - is_deriv r r' -> - is_deriv (Max l r) ((l' + r')/2) *) - -*) - End isderiv. - - Section deriv2. - - Fixpoint df_eval_deriv {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v:var_type) : option (definition_function_types_interp T) - := (match df with - | Number _ _ => Some 0 - | Constant t _ x => Some - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval_deriv σ (dfs i) v) - | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval_deriv σ (df i j) v) - | Var x _ => Some (let t:=snd x in - match t return definition_function_types_interp t with - | DTfloat => if x == v then 1 else 0 - | DTVector n => ConstVector n (if x == v then 1 else 0) - | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0) - end) - | Plus _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some le, Some lr => Some (le + lr) - | _, _ => None - end - | Minus _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some le, Some lr => Some (le - lr) - | _, _ => None - end - | Times _ l r => - match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with - | Some le, Some ld, Some re, Some rd => - Some (le * rd + - (ld * re)) - | _, _, _, _ => None - end - | Divide _ l r => - match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with - | Some le, Some ld, Some re, Some rd => - Some ((ld / re) - ((le * rd) / (re * re))) - | _, _, _, _ => None - end - | Square _ e => - match df_eval σ e, df_eval_deriv σ e v with - | Some ee, Some ed => Some (2 * ee * ed) - | _, _ => None - end - | Exp _ e => - match df_eval σ e, df_eval_deriv σ e v with - | Some ee, Some ed => Some (ed * Fexp ee) - | _, _ => None - end - | Log _ e => - match df_eval σ e, df_eval_deriv σ e v with - | Some ee, Some ed => Some (ed / ee) - | _, _ => None - end - | Abs _ e => - match df_eval σ e, df_eval_deriv σ e v with - | Some ee, Some ed => Some (ed * (sign ee)) - | _, _ => None - end - | Sign _ e => - match df_eval_deriv σ e v with - | Some _ => Some 0 - | None => None - end - | PSign _ e => - match df_eval_deriv σ e v with - | Some _ => Some 0 - | None => None - end - | Max _ l r => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - if le <= re then df_eval_deriv σ r v else df_eval_deriv σ l v - | _, _ => None - end - | VectorElem n _ l i => - match (df_eval_deriv σ l v) with - | Some l' => Some (l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_deriv σ l v) with - | Some l' => Some (l' i j) - | _ => None - end - | VectorDot n _ l r => - match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with - | Some le, Some ld, Some re, Some rd => - Some (vsum (fun j => (le j) * (rd j) + (ld j) * (re j))) - | _, _, _, _ => None - end - | VectorSum n _ l => - match df_eval_deriv σ l v with - | Some ld => - Some (vsum ld) - | _ => None - end - | MatrixSum n m _ l => - match df_eval_deriv σ l v with - | Some ld => - Some (msum ld) - | _ => None - end - | VectorScalMult n _ x r => - match df_eval σ x, df_eval_deriv σ x v, df_eval σ r, df_eval_deriv σ r v with - | Some xe, Some xd, Some re, Some rd => Some (fun j => xe * (rd j) + xd * (re j)) - | _, _, _, _ => None - end - | MatrixScalMult n m _ x r => - match df_eval σ x, df_eval_deriv σ x v, df_eval σ r, df_eval_deriv σ r v with - | Some xe, Some xd, Some re, Some rd => Some (fun i j => xe * (rd i j) + xd * (re i j)) - | _, _, _, _ => None - end - | MatrixVectorMult n m _ l r => - match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with - | Some le, Some ld, Some re, Some rd => - Some (fun i => vsum (fun j => (le i j)*(rd j) + (ld i j)*(re j))) - | _, _, _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some le, Some re => - Some (fun i j => (le i j) + (re i)) - | _, _ => None - end - | MatrixMult n m p _ l r => - match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with - | Some le, Some ld, Some re, Some rd => - Some (fun i k => vsum (fun j => (le i j)*(rd j k) + (ld i j)*(re j k))) - | _, _, _, _ => None - end - | VectorPlus n _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some l', Some r' => Some (fun i => (l' i) + (r' i)) - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some l', Some r' => Some (fun i => (l' i) - (r' i)) - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j)) - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j)) - | _, _ => None - end - | VectorApply n _ x s r => - match df_eval σ r, df_eval_deriv σ r v with - | Some re, Some rd => - vectoro_to_ovector - (fun i => - let xv := (x, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv (re i)) nil) s xv with - | Some sd => Some ((rd i) * sd) - | _ => None - end) - | _, _ => None - end - | MatrixApply n m _ x s r => - match df_eval σ r, df_eval_deriv σ r v with - | Some re, Some rd => - matrixo_to_omatrix - (fun i j => - let xv := (x, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv (re i j)) nil) s xv with - | Some sd => Some ((rd i j) * sd) - | _ => None - end) - | _, _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_eval σ l, df_eval_deriv σ l v with - | Some le, Some ld => - match (vectoro_to_ovector - (fun i => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv1 (le i)) - (cons (mk_env_entry xv2 (r i)) nil)) s xv1 with - | Some sd => Some ((ld i) * sd) - | _ => None - end)) with - | Some vv => Some (vsum vv) - | _ => None - end - | _, _ => None - end - | MLossfun n m _ v1 v2 s l r => - match df_eval σ l, df_eval_deriv σ l v with - | Some le, Some ld => - match (matrixo_to_omatrix - (fun i j => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv1 (le i j)) - (cons (mk_env_entry xv2 (r i j)) nil)) s xv1 with - | Some sd => Some ((ld i j) * sd) - | _ => None - end)) with - | Some vv => Some ((msum vv)/(FfromZ (Z.of_nat m))) - | _ => None - end - | _, _ => None - end - end). - - Definition mk_genvar_env (s:SubVar) := mk_env_entry (s, DTfloat) (FfromZ (Z.of_nat 1)) :: nil. - - (* the v environment below pairs variables with their derivatives *) - (* in some sense this is giving a directional derivative defined by v *) - Fixpoint df_eval_deriv_genvar {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v:df_env) : option (definition_function_types_interp T) - := (match df with - | Number _ _ => Some 0 - | Constant t _ x => Some - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval_deriv_genvar σ (dfs i) v) - | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval_deriv_genvar σ (df i j) v) - | Var x _ => Some ( - match vartlookup v x with - | Some val => val - | _ => - match (snd x) with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - end) - | Plus _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some le, Some lr => Some (le + lr) - | _, _ => None - end - | Minus _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some le, Some lr => Some (le - lr) - | _, _ => None - end - | Times _ l r => - match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some le, Some ld, Some re, Some rd => - Some (le * rd + - (ld * re)) - | _, _, _, _ => None - end - | Divide _ l r => - match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some le, Some ld, Some re, Some rd => - Some ((ld / re) - ((le * rd) / (re * re))) - | _, _, _, _ => None - end - | Square _ e => - match df_eval σ e, df_eval_deriv_genvar σ e v with - | Some ee, Some ed => Some (2 * ee * ed) - | _, _ => None - end - | Exp _ e => - match df_eval σ e, df_eval_deriv_genvar σ e v with - | Some ee, Some ed => Some (ed * Fexp ee) - | _, _ => None - end - | Log _ e => - match df_eval σ e, df_eval_deriv_genvar σ e v with - | Some ee, Some ed => Some (ed / ee) - | _, _ => None - end - | Abs _ e => - match df_eval σ e, df_eval_deriv_genvar σ e v with - | Some ee, Some ed => Some (ed * (sign ee)) - | _, _ => None - end - | Sign _ e => - match df_eval_deriv_genvar σ e v with - | Some _ => Some 0 - | None => None - end - | PSign _ e => - match df_eval_deriv_genvar σ e v with - | Some _ => Some 0 - | None => None - end - | Max _ l r => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - if le <= re then df_eval_deriv_genvar σ r v else df_eval_deriv_genvar σ l v - | _, _ => None - end - | VectorElem n _ l i => - match (df_eval_deriv_genvar σ l v) with - | Some l' => Some (l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_deriv_genvar σ l v) with - | Some l' => Some (l' i j) - | _ => None - end - | VectorDot n _ l r => - match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some le, Some ld, Some re, Some rd => - Some (vsum (fun j => (le j) * (rd j) + (ld j) * (re j))) - | _, _, _, _ => None - end - | VectorSum n _ l => - match df_eval_deriv_genvar σ l v with - | Some ld => - Some (vsum ld) - | _ => None - end - | MatrixSum n m _ l => - match df_eval_deriv_genvar σ l v with - | Some ld => - Some (msum ld) - | _ => None - end - | VectorScalMult n _ x r => - match df_eval σ x, df_eval_deriv_genvar σ x v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some xe, Some xd, Some re, Some rd => Some (fun j => xe * (rd j) + xd * (re j)) - | _, _, _, _ => None - end - | MatrixScalMult n m _ x r => - match df_eval σ x, df_eval_deriv_genvar σ x v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some xe, Some xd, Some re, Some rd => Some (fun i j => xe * (rd i j) + xd * (re i j)) - | _, _, _, _ => None - end - | MatrixVectorMult n m _ l r => - match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some le, Some ld, Some re, Some rd => - Some (fun i => vsum (fun j => (le i j)*(rd j) + (ld i j)*(re j))) - | _, _, _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some le, Some re => - Some (fun i j => (le i j) + (re i)) - | _, _ => None - end - | MatrixMult n m p _ l r => - match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some le, Some ld, Some re, Some rd => - Some (fun i k => vsum (fun j => (le i j)*(rd j k) + (ld i j)*(re j k))) - | _, _, _, _ => None - end - | VectorPlus n _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some l', Some r' => Some (fun i => (l' i) + (r' i)) - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some l', Some r' => Some (fun i => (l' i) - (r' i)) - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j)) - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j)) - | _, _ => None - end - | VectorApply n _ x s r => - match df_eval σ r, df_eval_deriv_genvar σ r v with - | Some re, Some rd => - vectoro_to_ovector - (fun i => - let xv := (x, DTfloat):var_type in - match df_eval_deriv_genvar (mk_env_entry xv (re i) :: nil) s - (mk_genvar_env x) with - | Some sd => Some ((rd i) * sd) - | _ => None - end) - | _, _ => None - end - | MatrixApply n m _ x s r => - match df_eval σ r, df_eval_deriv_genvar σ r v with - | Some re, Some rd => - matrixo_to_omatrix - (fun i j => - let xv := (x, DTfloat):var_type in - match df_eval_deriv_genvar (mk_env_entry xv (re i j) :: nil) s - (mk_genvar_env x) with - | Some sd => Some ((rd i j) * sd) - | _ => None - end) - | _, _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_eval σ l, df_eval_deriv_genvar σ l v with - | Some le, Some ld => - match (vectoro_to_ovector - (fun i => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv_genvar ( mk_env_entry xv1 (le i) :: - mk_env_entry xv2 (r i) :: nil) s - (mk_genvar_env v1) with - | Some sd => Some ((ld i) * sd) - | _ => None - end)) with - | Some vv => Some (vsum vv) - | _ => None - end - | _, _ => None - end - | MLossfun n m _ v1 v2 s l r => - match df_eval σ l, df_eval_deriv_genvar σ l v with - | Some le, Some ld => - match (matrixo_to_omatrix - (fun i j => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv_genvar ( mk_env_entry xv1 (le i j) :: - mk_env_entry xv2 (r i j) :: nil) s - ( mk_genvar_env v1) with - | Some sd => Some ((ld i j) * sd) - | _ => None - end)) with - | Some vv => Some ((msum vv)/(FfromZ (Z.of_nat m))) - | _ => None - end - | _, _ => None - end - end). - - - Definition definition_function_types_interp_prod (vart dft:definition_function_types) : Type - := match vart with - | DTfloat => definition_function_types_interp dft - | DTVector n => Vector (definition_function_types_interp dft) n - | DTMatrix m n => Matrix (definition_function_types_interp dft) m n - end. - - - Definition UnitVector (n:nat) (j : {n':nat | (n' < n)%nat}) : Vector float n := - fun i => if (proj1_sig i) == (proj1_sig j) then 1 else 0. - - Definition UnitMatrix (n m: nat) - (i : {n':nat | (n' < n)%nat}) - (j : {m':nat | (m' < m)%nat}) : Matrix float n m := - fun a b => if (proj1_sig a) == (proj1_sig i) then - (if (proj1_sig b) == (proj1_sig j) then 1 else 0) - else 0. - - Definition const_env (v : var_type) : df_env - := match (snd v) with - | DTfloat => ((mk_env_entry (fst v, DTfloat) 0)::nil) - | DTVector n => ((mk_env_entry (fst v, DTVector n) (ConstVector n 0))::nil) - | DTMatrix m n => ((mk_env_entry (fst v, DTMatrix m n) (ConstMatrix m n 0))::nil) - end. - - Fixpoint df_eval_tree_deriv {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (v:var_type) : option (definition_function_types_interp T) - := (match df with - | Number _ _ => Some 0 - | Constant t _ x => Some - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval_tree_deriv σ (dfs i) v) - | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval_tree_deriv σ (df i j) v) - | Var x _ => Some (let t:=snd x in - match t return definition_function_types_interp t with - | DTfloat => if x == v then 1 else 0 - | DTVector n => ConstVector n (if x == v then 1 else 0) - | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0) - end) - | Plus _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some le, Some lr => Some (le + lr) - | _, _ => None - end - | Minus _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some le, Some lr => Some (le - lr) - | _, _ => None - end - | Times _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some (le * rd + - (ld * re)) - | _, _ => None - end - | Divide _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some ((ld / re) - ((le * rd) / (re * re))) - | _, _ => None - end - | Square _ e => - let ee := get_annotation e in - match df_eval_tree_deriv σ e v with - | Some ed => Some (2 * ee * ed) - | _ => None - end - | Exp _ e => - let ee := get_annotation e in - match df_eval_tree_deriv σ e v with - | Some ed => Some (ed * Fexp ee) - | _ => None - end - | Log _ e => - let ee := get_annotation e in - match df_eval_tree_deriv σ e v with - | Some ed => Some (ed / ee) - | _ => None - end - | Abs _ e => - let ee := get_annotation e in - match df_eval_tree_deriv σ e v with - | Some ed => Some (ed * (sign ee)) - | _ => None - end - | Sign _ e => - match df_eval_tree_deriv σ e v with - | Some _ => Some 0 - | None => None - end - | PSign _ e => - match df_eval_tree_deriv σ e v with - | Some _ => Some 0 - | None => None - end - | Max _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - if le <= re then df_eval_tree_deriv σ r v else df_eval_tree_deriv σ l v - | VectorElem n _ l i => - match (df_eval_tree_deriv σ l v) with - | Some l' => Some (l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_tree_deriv σ l v) with - | Some l' => Some (l' i j) - | _ => None - end - | VectorDot n _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some (vsum (fun j => (le j) * (rd j) + (ld j) * (re j))) - | _, _ => None - end - | VectorSum n _ l => - match df_eval_tree_deriv σ l v with - | Some ld => - Some (vsum ld) - | _ => None - end - | MatrixSum n m _ l => - match df_eval_tree_deriv σ l v with - | Some ld => - Some (msum ld) - | _ => None - end - | VectorScalMult n _ x r => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_deriv σ x v, df_eval_tree_deriv σ r v with - | Some xd, Some rd => Some (fun j => xe * (rd j) + xd * (re j)) - | _, _ => None - end - | MatrixScalMult n m _ x r => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_deriv σ x v, df_eval_tree_deriv σ r v with - | Some xd, Some rd => Some (fun i j => xe * (rd i j) + xd * (re i j)) - | _, _ => None - end - | MatrixVectorMult n m _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some (fun i => vsum (fun j => (le i j)*(rd j) + (ld i j)*(re j))) - | _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some (fun i j => (ld i j) + (rd i)) - | _, _ => None - end - | MatrixMult n m p _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some (fun i k => vsum (fun j => (le i j)*(rd j k) + (ld i j)*(re j k))) - | _, _ => None - end - | VectorPlus n _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some l', Some r' => Some (fun i => (l' i) + (r' i)) - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some l', Some r' => Some (fun i => (l' i) - (r' i)) - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j)) - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j)) - | _, _ => None - end - | VectorApply n _ x s r => - let re := get_annotation r in - match df_eval_tree_deriv σ r v with - | Some rd => - vectoro_to_ovector - (fun i => - let xv := (x, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv (re i)) nil ) s xv with - | Some sd => Some ((rd i) * sd) - | _ => None - end) - | _ => None - end - | MatrixApply n m _ x s r => - let re := get_annotation r in - match df_eval_tree_deriv σ r v with - | Some rd => - matrixo_to_omatrix - (fun i j => - let xv := (x, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv (re i j)) nil ) s xv with - | Some sd => Some ((rd i j) * sd) - | _ => None - end) - | _ => None - end - | VLossfun n _ v1 v2 s l r => - let le := get_annotation l in - match df_eval_tree_deriv σ l v with - | Some ld => - match (vectoro_to_ovector - (fun i => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv1 (le i)) - (cons (mk_env_entry xv2 (r i)) nil)) s xv1 with - | Some sd => Some ((ld i) * sd) - | _ => None - end)) with - | Some vv => Some (vsum vv) - | _ => None - end - | _ => None - end - | MLossfun n m _ v1 v2 s l r => - let le := get_annotation l in - match df_eval_tree_deriv σ l v with - | Some ld => - match (matrixo_to_omatrix - (fun i j => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv1 (le i j)) - (cons (mk_env_entry xv2 (r i j)) nil )) s xv1 with - | Some sd => Some ((ld i j) * sd) - | _ => None - end)) with - | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m))) - | _ => None - end - | _ => None - end - end). - - Fixpoint df_eval_tree_deriv_genvar {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (v:df_env) : option (definition_function_types_interp T) - := (match df with - | Number _ _ => Some 0 - | Constant t _ x => Some - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval_tree_deriv_genvar σ (dfs i) v) - | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval_tree_deriv_genvar σ (df i j) v) - | Var x _ => Some ( - match vartlookup v x with - | Some val => val - | _ => - match (snd x) with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - end) - | Plus _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some le, Some lr => Some (le + lr) - | _, _ => None - end - | Minus _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some le, Some lr => Some (le - lr) - | _, _ => None - end - | Times _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some (le * rd + - (ld * re)) - | _, _ => None - end - | Divide _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some ((ld / re) - ((le * rd) / (re * re))) - | _, _ => None - end - | Square _ e => - let ee := get_annotation e in - match df_eval_tree_deriv_genvar σ e v with - | Some ed => Some (2 * ee * ed) - | _ => None - end - | Exp _ e => - let ee := get_annotation e in - match df_eval_tree_deriv_genvar σ e v with - | Some ed => Some (ed * Fexp ee) - | _ => None - end - | Log _ e => - let ee := get_annotation e in - match df_eval_tree_deriv_genvar σ e v with - | Some ed => Some (ed / ee) - | _ => None - end - | Abs _ e => - let ee := get_annotation e in - match df_eval_tree_deriv_genvar σ e v with - | Some ed => Some (ed * (sign ee)) - | _ => None - end - | Sign _ e => - match df_eval_tree_deriv_genvar σ e v with - | Some _ => Some 0 - | None => None - end - | PSign _ e => - match df_eval_tree_deriv_genvar σ e v with - | Some _ => Some 0 - | None => None - end - | Max _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - if le <= re then df_eval_tree_deriv_genvar σ r v else df_eval_tree_deriv_genvar σ l v - | VectorElem n _ l i => - match (df_eval_tree_deriv_genvar σ l v) with - | Some l' => Some (l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_tree_deriv_genvar σ l v) with - | Some l' => Some (l' i j) - | _ => None - end - | VectorDot n _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some (vsum (fun j => (le j) * (rd j) + (ld j) * (re j))) - | _, _ => None - end - | VectorSum n _ l => - match df_eval_tree_deriv_genvar σ l v with - | Some ld => - Some (vsum ld) - | _ => None - end - | MatrixSum n m _ l => - match df_eval_tree_deriv_genvar σ l v with - | Some ld => - Some (msum ld) - | _ => None - end - | VectorScalMult n _ x r => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_deriv_genvar σ x v, df_eval_tree_deriv_genvar σ r v with - | Some xd, Some rd => Some (fun j => xe * (rd j) + xd * (re j)) - | _, _ => None - end - | MatrixScalMult n m _ x r => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_deriv_genvar σ x v, df_eval_tree_deriv_genvar σ r v with - | Some xd, Some rd => Some (fun i j => xe * (rd i j) + xd * (re i j)) - | _, _ => None - end - | MatrixVectorMult n m _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some (fun i => vsum (fun j => (le i j)*(rd j) + (ld i j)*(re j))) - | _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some (fun i j => (ld i j) + (rd i)) - | _, _ => None - end - | MatrixMult n m p _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some (fun i k => vsum (fun j => (le i j)*(rd j k) + (ld i j)*(re j k))) - | _, _ => None - end - | VectorPlus n _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some l', Some r' => Some (fun i => (l' i) + (r' i)) - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some l', Some r' => Some (fun i => (l' i) - (r' i)) - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j)) - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j)) - | _, _ => None - end - | VectorApply n _ x s r => - let re := get_annotation r in - match df_eval_tree_deriv_genvar σ r v with - | Some rd => - vectoro_to_ovector - (fun i => - let xv := (x, DTfloat):var_type in - match df_eval_deriv_genvar (cons (mk_env_entry xv (re i)) nil ) s - (mk_genvar_env x) with - | Some sd => Some ((rd i) * sd) - | _ => None - end) - | _ => None - end - | MatrixApply n m _ x s r => - let re := get_annotation r in - match df_eval_tree_deriv_genvar σ r v with - | Some rd => - matrixo_to_omatrix - (fun i j => - let xv := (x, DTfloat):var_type in - match df_eval_deriv_genvar (cons (mk_env_entry xv (re i j)) nil) s - (mk_genvar_env x) with - | Some sd => Some ((rd i j) * sd) - | _ => None - end) - | _ => None - end - | VLossfun n _ v1 v2 s l r => - let le := get_annotation l in - match df_eval_tree_deriv_genvar σ l v with - | Some ld => - match (vectoro_to_ovector - (fun i => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv_genvar (cons (mk_env_entry xv1 (le i)) - (cons (mk_env_entry xv2 (r i)) nil )) s - (mk_genvar_env v1) with - | Some sd => Some ((ld i) * sd) - | _ => None - end)) with - | Some vv => Some (vsum vv) - | _ => None - end - | _ => None - end - | MLossfun n m _ v1 v2 s l r => - let le := get_annotation l in - match df_eval_tree_deriv_genvar σ l v with - | Some ld => - match (matrixo_to_omatrix - (fun i j => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv_genvar (cons (mk_env_entry xv1 (le i j)) - (cons (mk_env_entry xv2 (r i j)) nil)) s - (mk_genvar_env v1) with - | Some sd => Some ((ld i j) * sd) - | _ => None - end)) with - | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m))) - | _ => None - end - | _ => None - end - end). - - Definition vector_env_iter {n} {A} (f: A -> df_env -> option df_env) - (env: df_env) (v : Vector A n) : option df_env := - vector_fold_right (fun a oenv => match oenv with - | Some env => f a env - | _ => None - end) - (Some env) v. - - Fixpoint list_env_iter {A} (f: A -> df_env -> option df_env) - (oenv:option df_env) (l: list A) : option df_env := - match oenv, l with - | Some env, x :: l' => list_env_iter f (f x env) l' - | _, _ => oenv - end. - - Lemma list_env_iter_none {A} (f: A -> df_env -> option df_env) (l: list A) : - list_env_iter f None l = None. - Proof. - induction l. - now simpl. - now simpl. - Qed. - - Lemma list_env_iter_env_not_none {A} (f: A -> df_env -> option df_env) - (oenv : option df_env) (l: list A): - list_env_iter f oenv l <> None -> oenv <> None. - Proof. - intros. - destruct oenv. - + discriminate. - + rewrite list_env_iter_none in H. - tauto. - Qed. - - Lemma list_env_iter_app {A} (f: A -> df_env -> option df_env) - (oenv:option df_env) (l1 l2: list A) : - list_env_iter f oenv (l1++l2) = - list_env_iter f (list_env_iter f oenv l1) l2. - Proof. - revert l2 oenv. - induction l1; intros l2 oenv; simpl. - - now destruct oenv. - - destruct oenv. - + auto. - + now rewrite list_env_iter_none. - Qed. - - - Lemma list_env_iter_ext {A} f1 f2 oenv (l:list A) : - (forall x a, In x l -> f1 x a = f2 x a) -> - list_env_iter f1 oenv l = list_env_iter f2 oenv l. - Proof. - intros fa. - revert oenv. - induction l; intros oenv; intros; simpl - ; match_destr. - rewrite fa; simpl; intuition. - Qed. - - - Definition two_vector_env_iter {n} {A B} (f: A -> B -> df_env -> option df_env) - (env: df_env) (v: Vector A n) (w: Vector B n) : option df_env := - vector_env_iter (fun '(a,b) env => f a b env) env - (vector_zip v w). - - - Definition two_vector_env_iter_alt {n} {A B} (f: A -> B -> df_env -> option df_env) - (env: df_env) (v: Vector A n) (w: Vector B n) : option df_env := - list_env_iter (fun i env => f (v i) (w i) env) (Some env) (bounded_seq0 n). - - Definition matrix_env_iter {m n} {A} (f: A -> df_env -> option df_env) - (env: option df_env) (mat : Matrix A m n) : option df_env := - vector_fold_right - (fun vec oenv => - vector_fold_right (fun a oenv => match oenv with - | Some env => f a env - | _ => None - end) oenv vec - ) env mat. - - Definition two_matrix_env_iter {n m} {A B} (f: A -> B -> df_env -> option df_env) - (env: option df_env) (v: Matrix A n m) (w: Matrix B n m) : option df_env := - let vw := matrix_zip v w in - matrix_env_iter (fun '(a,b) e => f a b e) env vw. - - Definition two_matrix_env_iter_alt {n m} {A B} (f: A -> B -> df_env -> option df_env) - (env: df_env) (v: Matrix A n m) (w: Matrix B n m) : option df_env := - list_env_iter (fun i env => list_env_iter (fun j env => f (v i j) (w i j) env) - (Some env) (bounded_seq0 m)) - (Some env) (bounded_seq0 n). - - - Program Definition addvar (x : var_type) (grad_env:df_env) := - (match snd x as y return snd x = y -> - definition_function_types_interp y -> - definition_function_types_interp y with - | DTfloat => fun pf grad => match vartlookup grad_env x with - | Some val => grad + ((coerce _ val):float) - | _ => grad - end - | DTVector n => fun pf grad => match vartlookup grad_env x with - | Some val => fun i => (grad i) + (((coerce _ val):Vector float n) i) - | _ => grad - end - | DTMatrix m n => fun pf grad => match vartlookup grad_env x with - | Some val => fun i j => (((coerce _ val):Matrix float m n) i j) + (grad i j) - | _ => grad - end - end) (eq_refl _). - Next Obligation. - rewrite pf; reflexivity. - Defined. - Next Obligation. - rewrite pf; reflexivity. - Defined. - Next Obligation. - rewrite pf; reflexivity. - Defined. - - Definition gradenv_init1 (v : var_type) : env_entry_type := - mk_env_entry v - (match snd v as y return definition_function_types_interp y with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix n m => ConstMatrix n m 0 - end). - - Definition gradenv_init (dvars : list var_type) : df_env := - map gradenv_init1 dvars. - - Fixpoint df_eval_backprop_deriv {T Ann} (σ:df_env) (df:DefinedFunction Ann T) (grad_env:df_env) {struct df} : definition_function_types_interp T -> option df_env - := match df with - | Number _ _ => fun grad => Some grad_env - | Constant _ _ _ => fun grad => Some grad_env - | DVector n _ dfs => fun grad => - two_vector_env_iter_alt (fun x g genv => df_eval_backprop_deriv σ x genv g) - grad_env dfs grad - | DMatrix n m _ dfs => fun grad => - two_matrix_env_iter_alt (fun x g genv => df_eval_backprop_deriv σ x genv g) - grad_env dfs grad - | Var x _ => fun grad => - if vartlookup grad_env x then - Some (vart_update grad_env x (addvar x grad_env grad)) - else Some grad_env - | Plus _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad - | _ => None - end - | Minus _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (-grad) - | _ => None - end - | Times _ l r => fun grad => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - match df_eval_backprop_deriv σ l grad_env (re * grad) with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (le * grad) - | _ => None - end - | _, _ => None - end - | Divide _ l r => fun grad => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - match df_eval_backprop_deriv σ l grad_env (grad / re) with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (- le / (re * re) * grad) - | _ => None - end - | _, _ => None - end - | Square _ e => fun grad => - match df_eval σ e with - | Some ee => df_eval_backprop_deriv σ e grad_env (2 * ee * grad) - | _ => None - end - | Exp _ e => fun grad => - match df_eval σ e with - | Some ee => df_eval_backprop_deriv σ e grad_env (grad * Fexp ee) - | _ => None - end - | Log _ e => fun grad => - match df_eval σ e with - | Some ee => df_eval_backprop_deriv σ e grad_env (grad / ee) - | _ => None - end - | Abs _ e => fun grad => - match df_eval σ e with - | Some ee => df_eval_backprop_deriv σ e grad_env (grad * (sign ee)) - | _ => None - end - | Sign _ e => fun grad => df_eval_backprop_deriv σ e grad_env 0 - | PSign _ e => fun grad => df_eval_backprop_deriv σ e grad_env 0 - | Max _ l r => fun grad => - match df_eval σ l, - df_eval σ r with - | Some le, Some re => - if le <= re then - (df_eval_backprop_deriv σ r grad_env grad) else - (df_eval_backprop_deriv σ l grad_env grad) - | _, _ => None - end - | VectorElem n _ l i => fun grad => - let grad' := fun k => if proj1_sig k == proj1_sig i then grad else 0 in - df_eval_backprop_deriv σ l grad_env grad' - | MatrixElem m n _ l i j => fun grad => - let grad' := fun k1 k2 => - if (proj1_sig k1 == proj1_sig i) then - if (proj1_sig k2 == proj1_sig j) then grad else 0 - else 0 in - df_eval_backprop_deriv σ l grad_env grad' - | VectorDot n _ l r => fun grad => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - match df_eval_backprop_deriv σ l grad_env (vmap (fun rv => rv*grad) re) with - | Some grad_env' => - df_eval_backprop_deriv σ r grad_env' (vmap (fun lv => lv*grad) le) - | _ => None - end - | _, _ => None - end - | VectorSum n _ l => fun grad => - df_eval_backprop_deriv σ l grad_env (ConstVector n grad) - | MatrixSum n m _ l => fun grad => - df_eval_backprop_deriv σ l grad_env (ConstMatrix n m grad) - | VectorScalMult n _ x r => fun grad => - match df_eval σ x, df_eval σ r with - | Some xe, Some re => - match df_eval_backprop_deriv σ x grad_env (vsum (fun j => (re j) * (grad j))) with - | Some grad_env' => - df_eval_backprop_deriv σ r grad_env' (fun j => xe * (grad j)) - | _ => None - end - | _, _ => None - end - | MatrixScalMult n m _ x r => fun grad => - match df_eval σ x, df_eval σ r with - | Some xe, Some re => - match df_eval_backprop_deriv σ x grad_env (msum (fun i j => (re i j) * (grad i j))) with - | Some grad_env' => - df_eval_backprop_deriv σ r grad_env' (fun i j => (grad i j) * xe) - | _ => None - end - | _, _ => None - end - | MatrixVectorMult n m _ l r => fun grad => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - match df_eval_backprop_deriv σ l grad_env (fun i j => (grad i) * (re j)) with - - - | Some grad_env' => - df_eval_backprop_deriv σ r grad_env' (matrix_vector_mult (fun i j => le j i) grad) - | _ => None - end - | _, _ => None - end - | MatrixVectorAdd n m _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => - match list_env_iter - (fun i env => df_eval_backprop_deriv σ r env ((transpose grad) i)) - (Some grad_env') (bounded_seq0 m) with - | Some grad_env'' => Some grad_env'' - | _ => None - end - | _ => None - end - | MatrixMult n m p _ l r => fun grad => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - match df_eval_backprop_deriv σ l grad_env (matrix_mult grad (fun i j => (re j i))) with - - - | Some grad_env' => - df_eval_backprop_deriv σ r grad_env' (matrix_mult (fun i j => le j i) grad) - | _ => None - end - | _, _ => None - end - | VectorPlus n _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad - | _ => None - end - | VectorMinus n _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (fun i => - (grad i)) - | _ => None - end - | MatrixPlus n m _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad - | _ => None - end - | MatrixMinus n m _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (fun i j => - (grad i j)) - | _ => None - end - | VectorApply n _ x s r => fun grad => - match df_eval σ r with - | Some re => - let xv := (x, DTfloat):var_type in - let s' := df_deriv s xv in - let ograd := - vmap (fun '(rei, g) => - - match df_eval (cons (mk_env_entry xv rei) nil) s' with - | Some se => Some (g * se) - | _ => None - end) - (vector_zip re grad) in - match vectoro_to_ovector ograd with - | Some grad' => df_eval_backprop_deriv σ r grad_env grad' - | _ => None - end - | _ => None - end - | MatrixApply n m _ x s r => fun grad => - match df_eval σ r with - | Some re => - let xv := (x, DTfloat):var_type in - let s' := df_deriv s xv in - let ograd := - mmap (fun '(rei, g) => - match df_eval (cons (mk_env_entry xv rei) nil) s' with - | Some se => Some (g * se) - | _ => None - end) - (matrix_zip re grad) in - match matrixo_to_omatrix ograd with - | Some grad' => df_eval_backprop_deriv σ r grad_env grad' - | _ => None - end - | _ => None - end - | VLossfun n _ v1 v2 s l re => fun grad => - match df_eval σ l with - | Some le => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - let s' := df_deriv s xv1 in - let ograd := - vmap (fun '(lei, rei) => - let senv := cons (mk_env_entry xv1 lei) - (cons (mk_env_entry xv2 rei) nil) in - match df_eval senv s' with - | Some se => Some (grad * se) - | _ => None - end) - (vector_zip le re) in - match vectoro_to_ovector ograd with - | Some grad' => df_eval_backprop_deriv σ l grad_env grad' - | _ => None - end - | _ => None - end - | MLossfun n m _ v1 v2 s l re => fun grad => - match df_eval σ l with - | Some le => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - let s' := df_deriv s xv1 in - let ograd := - mmap (fun '(lei, rei) => - let senv := cons (mk_env_entry xv1 lei) - (cons (mk_env_entry xv2 rei) nil) in - match df_eval senv s' with - | Some se => Some ((grad * se)/(FfromZ (Z.of_nat m))) - | _ => None - end) - (matrix_zip le re) in - match matrixo_to_omatrix ograd with - | Some grad' => df_eval_backprop_deriv σ l grad_env grad' - | _ => None - end - | _ => None - end - end. - - Definition lifted_type (B:Type) T - := match T with - | DTfloat => B - | DTVector n => Vector B n - | DTMatrix m n => Matrix B m n - end. - - Fixpoint df_eval_tree_backprop_deriv {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (grad_env:df_env) {struct df} : definition_function_types_interp T -> option df_env - := match df with - | Number _ _ => fun grad => Some grad_env - | Constant _ _ _ => fun grad => Some grad_env - | DVector n _ dfs => fun grad => - two_vector_env_iter_alt (fun x g genv => df_eval_tree_backprop_deriv σ x genv g) - grad_env dfs grad - | DMatrix n m _ dfs => fun grad => - two_matrix_env_iter_alt (fun x g genv => df_eval_tree_backprop_deriv σ x genv g) - grad_env dfs grad - | Var x _ => fun grad => - if vartlookup grad_env x then - Some (vart_update grad_env x (addvar x grad_env grad)) - else Some grad_env - | Plus _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad - | _ => None - end - | Minus _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (-grad) - | _ => None - end - | Times _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_backprop_deriv σ l grad_env (re * grad) with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (le * grad) - | _ => None - end - | Divide _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_backprop_deriv σ l grad_env (grad / re) with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (- le / (re * re) * grad) - | _ => None - end - | Square _ e => fun grad => - let ee := get_annotation e in - df_eval_tree_backprop_deriv σ e grad_env (2 * ee * grad) - | Exp _ e => fun grad => - let ee := get_annotation e in - df_eval_tree_backprop_deriv σ e grad_env (grad * Fexp ee) - | Log _ e => fun grad => - let ee := get_annotation e in - df_eval_tree_backprop_deriv σ e grad_env (grad / ee) - | Abs _ e => fun grad => - let ee := get_annotation e in - df_eval_tree_backprop_deriv σ e grad_env (grad * (sign ee)) - | Sign _ e => fun grad => df_eval_tree_backprop_deriv σ e grad_env 0 - | PSign _ e => fun grad => df_eval_tree_backprop_deriv σ e grad_env 0 - | Max _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - if le <= re then - (df_eval_tree_backprop_deriv σ r grad_env grad) else - (df_eval_tree_backprop_deriv σ l grad_env grad) - | VectorElem n _ l i => fun grad => - let grad' := fun k => if proj1_sig k == proj1_sig i then grad else 0 in - df_eval_tree_backprop_deriv σ l grad_env grad' - | MatrixElem m n _ l i j => fun grad => - let grad' := fun k1 k2 => - if (proj1_sig k1 == proj1_sig i) then - if (proj1_sig k2 == proj1_sig j) then grad else 0 - else 0 in - df_eval_tree_backprop_deriv σ l grad_env grad' - | VectorDot n _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_backprop_deriv σ l grad_env (vmap (fun rv => rv*grad) re) with - | Some grad_env' => - df_eval_tree_backprop_deriv σ r grad_env' (vmap (fun lv => lv*grad) le) - | _ => None - end - | VectorSum n _ l => fun grad => - df_eval_tree_backprop_deriv σ l grad_env (ConstVector n grad) - | MatrixSum n m _ l => fun grad => - df_eval_tree_backprop_deriv σ l grad_env (ConstMatrix n m grad) - | VectorScalMult n _ x r => fun grad => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_backprop_deriv σ x grad_env (vsum (fun j => (re j) * (grad j))) with - | Some grad_env' => - df_eval_tree_backprop_deriv σ r grad_env' (fun j => xe * (grad j)) - | _ => None - end - | MatrixScalMult n m _ x r => fun grad => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_backprop_deriv σ x grad_env (msum (fun i j => (re i j) * (grad i j))) with - | Some grad_env' => - df_eval_tree_backprop_deriv σ r grad_env' (fun i j => (grad i j) * xe) - | _ => None - end - | MatrixVectorMult n m _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_backprop_deriv σ l grad_env (fun i j => (grad i) * (re j)) with - | Some grad_env' => - df_eval_tree_backprop_deriv σ r grad_env' (matrix_vector_mult (fun i j => le j i) grad) - | _ => None - end - | MatrixVectorAdd n m _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => - match list_env_iter - (fun i env => df_eval_tree_backprop_deriv σ r env ((transpose grad) i)) - (Some grad_env') (bounded_seq0 m) with - | Some grad_env'' => Some grad_env'' - | _ => None - end - | _ => None - end - | MatrixMult n m p _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_backprop_deriv σ l grad_env (matrix_mult grad (fun i j => (re j i))) with - | Some grad_env' => - df_eval_tree_backprop_deriv σ r grad_env' (matrix_mult (fun i j => le j i) grad) - | _ => None - end - | VectorPlus n _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad - | _ => None - end - | VectorMinus n _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (fun i => - (grad i)) - | _ => None - end - | MatrixPlus n m _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad - | _ => None - end - | MatrixMinus n m _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (fun i j => - (grad i j)) - | _ => None - end - | VectorApply n _ x s r => fun grad => - let re := get_annotation r in - let xv := (x, DTfloat):var_type in - let s' := df_deriv s xv in - let ograd := - vmap (fun '(rei, g) => - match df_eval (cons (mk_env_entry xv rei) nil) s' with - | Some se => Some (g * se) - | _ => None - end) - (vector_zip re grad) in - match vectoro_to_ovector ograd with - | Some grad' => df_eval_tree_backprop_deriv σ r grad_env grad' - | _ => None - end - | MatrixApply n m _ x s r => fun grad => - let re := get_annotation r in - let xv := (x, DTfloat):var_type in - let s' := df_deriv s xv in - let ograd := - mmap (fun '(rei, g) => - match df_eval (cons (mk_env_entry xv rei) nil) s' with - | Some se => Some (g * se) - | _ => None - end) - (matrix_zip re grad) in - match matrixo_to_omatrix ograd with - | Some grad' => df_eval_tree_backprop_deriv σ r grad_env grad' - | _ => None - end - | VLossfun n _ v1 v2 s l re => fun grad => - let le := get_annotation l in - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - let s' := df_deriv s xv1 in - let ograd := - vmap (fun '(lei, rei) => - let senv := cons (mk_env_entry xv1 lei) - (cons (mk_env_entry xv2 rei) nil) in - match df_eval senv s' with - | Some se => Some (grad * se) - | _ => None - end) - (vector_zip le re) in - match vectoro_to_ovector ograd with - | Some grad' => df_eval_tree_backprop_deriv σ l grad_env grad' - | _ => None - end - | MLossfun n m _ v1 v2 s l re => fun grad => - let le := get_annotation l in - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - let s' := df_deriv s xv1 in - let ograd := - mmap (fun '(lei, rei) => - let senv := cons (mk_env_entry xv1 lei) - (cons (mk_env_entry xv2 rei) nil) in - match df_eval senv s' with - | Some se => Some ((grad * se) / (FfromZ (Z.of_nat m))) - | _ => None - end) - (matrix_zip le re) in - match matrixo_to_omatrix ograd with - | Some grad' => df_eval_tree_backprop_deriv σ l grad_env grad' - | _ => None - end - end. - - Definition o_df_env_to_df_env (oenv : option df_env) : df_env := - match oenv with - | Some env => env - | _ => nil - end. - - - Definition backprop_lookup (oenv:option df_env) (a:var_type) : - option (definition_function_types_interp (snd a)) := - match oenv with - | Some env => - match vartlookup env a with - | Some val => Some val - | _ => None - end - | _ => None - end. - - Definition is_scalar_df_type (dft:definition_function_types) : Prop - := match dft with - | DTfloat => True - | _ => False - end. - - Fixpoint is_scalar_function {Ann} {T} (df:DefinedFunction Ann T) : Prop - := match df with - | Number _ _ => True - | Constant t _ _ => is_scalar_df_type t - | Var v _ => is_scalar_df_type (snd v) - | Plus _ l r => is_scalar_function l /\ is_scalar_function r - | Minus _ l r => is_scalar_function l /\ is_scalar_function r - | Times _ l r => is_scalar_function l /\ is_scalar_function r - | Divide _ l r => is_scalar_function l /\ is_scalar_function r - | Square _ e => is_scalar_function e - | Exp _ e => is_scalar_function e - | Log _ e => is_scalar_function e - | Abs _ e => is_scalar_function e - | Sign _ e => is_scalar_function e - | PSign _ e => is_scalar_function e - | Max _ l r => is_scalar_function l /\ is_scalar_function r - | _ => False - end. - - Fixpoint has_scalar_functions {Ann} {T} - (df:DefinedFunction Ann T) {struct df}: Prop - := match df with - | Number _ _ => True - | Constant _ _ _ => True - | DVector n _ vec => - vforall (has_scalar_functions) vec - | DMatrix n m _ mat => - vforall (vforall (has_scalar_functions)) mat - | Var _ _ => True - | Plus _ l r => (has_scalar_functions l) /\ (has_scalar_functions r) - | Minus _ l r => (has_scalar_functions l) /\ (has_scalar_functions r) - | Times _ l r => (has_scalar_functions l) /\ (has_scalar_functions r) - | Divide _ l r => (has_scalar_functions l) /\ (has_scalar_functions r) - | Square _ l => has_scalar_functions l - | Exp _ l => has_scalar_functions l - | Log _ l => has_scalar_functions l - | Abs _ l => has_scalar_functions l - | Sign _ l => has_scalar_functions l - | PSign _ l => has_scalar_functions l - | Max _ l r => (has_scalar_functions l) /\ (has_scalar_functions r) - | VectorDot n _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | VectorSum _ _ l => has_scalar_functions l - | MatrixSum _ _ _ l => has_scalar_functions l - | VectorElem _ _ vec i => has_scalar_functions vec - | MatrixElem _ _ _ mat i j => has_scalar_functions mat - | MatrixVectorMult _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | MatrixVectorAdd _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | MatrixMult _ _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | VectorPlus _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | VectorMinus _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | MatrixPlus _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | MatrixMinus _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | VectorScalMult _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | MatrixScalMult _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | VectorApply _ _ _ s l => is_scalar_function s /\ has_scalar_functions l - | MatrixApply _ _ _ _ s l => is_scalar_function s /\ has_scalar_functions l - | VLossfun _ _ _ _ s l _ => is_scalar_function s /\ has_scalar_functions l - | MLossfun _ _ _ _ _ s l _ => is_scalar_function s /\ has_scalar_functions l - end. - - Lemma is_scalar_function_has_scalar_functions {Ann} {T} (df:DefinedFunction Ann T) : - is_scalar_function df -> has_scalar_functions df. - Proof. - induction df; firstorder. - Qed. - - Hint Resolve is_scalar_function_has_scalar_functions : fml. - - Definition DefinedFunction_ind_unit_has_scalar_functions - (P : forall (d : definition_function_types), DefinedFunction UnitAnn d -> Prop) - (f : forall (ann : UnitAnn DTfloat) (x : float), - P DTfloat (Number ann x)) - (f0 : forall (t : definition_function_types) - (ann : UnitAnn t) (x : definition_function_types_interp t), P t (Constant ann x)) - (f1 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (x : Vector (DefinedFunction UnitAnn DTfloat) n), - (forall s : {n' : nat | (n' < n)%nat}, P DTfloat (x s)) -> - P (DTVector n) (DVector ann x)) - (f2 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (x : Matrix (DefinedFunction UnitAnn DTfloat) n m), - (forall (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}), - P DTfloat (x s s0)) -> P (DTMatrix n m) (DMatrix ann x)) - (f3 : forall (v : var_type) (ann : UnitAnn (snd v)), - P (snd v) (Var v ann)) - (f4 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Plus ann l r)) - (f5 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Minus ann l r)) - (f6 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Times ann l r)) - (f7 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Divide ann l r)) - (f8 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Square ann e)) - (f9 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Exp ann e)) - (f10 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Log ann e)) - (f11 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Abs ann e)) - (f12 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Sign ann e)) - (f13 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (PSign ann e)) - (f14 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Max ann l r)) - (f15 : forall (n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P DTfloat (VectorDot ann l r)) - (f16 : forall (n : nat) (ann : UnitAnn DTfloat) - (v : DefinedFunction UnitAnn (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v)) - (f17 : forall (m n : nat) (ann : UnitAnn DTfloat) - (v : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v)) - (f18 : forall (n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i)) - (f19 : forall (m n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}), - P DTfloat (MatrixElem ann l i j)) - (f20 : forall (m n : nat) (ann : UnitAnn (DTVector m)) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r)) - (f21 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n)) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction UnitAnn (DTVector m), - P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r)) - (f22 : forall (m p n : nat) (ann : UnitAnn (DTMatrix m n)) - (l : DefinedFunction UnitAnn (DTMatrix m p)), - P (DTMatrix m p) l -> - forall r : DefinedFunction UnitAnn (DTMatrix p n), - P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r)) - (f23 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r)) - (f24 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r)) - (f25 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (l : DefinedFunction UnitAnn (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r)) - (f26 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (l : DefinedFunction UnitAnn (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r)) - (f27 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (x : DefinedFunction UnitAnn DTfloat), - P DTfloat x -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l)) - (f28 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (x : DefinedFunction UnitAnn DTfloat), - P DTfloat x -> - forall l : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l)) - (f29 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - is_scalar_function s -> - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l)) - (f30 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - is_scalar_function s -> - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTMatrix m n), - P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l)) - (f31 : forall (n : nat) (ann : UnitAnn DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - is_scalar_function s -> - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> forall r : Vector float n, P DTfloat (VLossfun ann v1 v2 s l r)) - (f32 : forall (m n : nat) (ann : UnitAnn DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - is_scalar_function s -> - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTMatrix m n), - P (DTMatrix m n) l -> - forall r : Matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r)) - : forall (d : definition_function_types) - (d0 : DefinedFunction UnitAnn d) - (hs:has_scalar_functions d0), P d d0. - Proof. - refine (fix - F (d : definition_function_types) - (d0 : DefinedFunction UnitAnn d) - {struct d0} : has_scalar_functions d0 -> P d d0 := - match d0 as d2 in (DefinedFunction _ d1) return has_scalar_functions d2 -> (P d1 d2) with - | Number ann x => fun hs => f ann x - | @Constant _ t ann x => fun hs => f0 t ann x - | @DVector _ n ann x => fun hs => f1 n ann x (fun s : {n' : nat | (n' < n)%nat} => F DTfloat (x s) _) - | @DMatrix _ n m ann x => fun hs => - f2 n m ann x - (fun (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}) => - F DTfloat (x s s0) _) - | Var v ann => fun hs => f3 v ann - | Plus ann l r => fun hs => f4 ann l (F DTfloat l (proj1 hs)) r (F DTfloat r (proj2 hs)) - | Minus ann l r => fun hs => f5 ann l (F DTfloat l _) r (F DTfloat r _) - | Times ann l r => fun hs => f6 ann l (F DTfloat l _) r (F DTfloat r _) - | Divide ann l r => fun hs => f7 ann l (F DTfloat l _) r (F DTfloat r _) - | Square ann e => fun hs => f8 ann e (F DTfloat e _) - | Exp ann e => fun hs => f9 ann e (F DTfloat e _) - | Log ann e => fun hs => f10 ann e (F DTfloat e _) - | Abs ann e => fun hs => f11 ann e (F DTfloat e _) - | Sign ann e => fun hs => f12 ann e (F DTfloat e _) - | PSign ann e => fun hs => f13 ann e (F DTfloat e _) - | Max ann l r => fun hs => f14 ann l (F DTfloat l _) r (F DTfloat r _) - | @VectorDot _ n ann l r => fun hs => f15 n ann l (F (DTVector n) l _) r (F (DTVector n) r _) - | @VectorSum _ n ann v => fun hs => f16 n ann v (F (DTVector n) v _) - | @MatrixSum _ m n ann v => fun hs => f17 m n ann v (F (DTMatrix m n) v _) - | @VectorElem _ n ann l i => fun hs => f18 n ann l (F (DTVector n) l _) i - | @MatrixElem _ m n ann l i j => fun hs => f19 m n ann l (F (DTMatrix m n) l _) i j - | @MatrixVectorMult _ m n ann l r => fun hs => - f20 m n ann l (F (DTMatrix m n) l _) r (F (DTVector n) r _) - | @MatrixVectorAdd _ m n ann l r => fun hs => - f21 m n ann l (F (DTMatrix m n) l _) r (F (DTVector m) r _) - | @MatrixMult _ m p n ann l r => fun hs => - f22 m p n ann l (F (DTMatrix m p) l _) r (F (DTMatrix p n) r _) - | @VectorPlus _ n ann l r => fun hs => f23 n ann l (F (DTVector n) l _) r (F (DTVector n) r _) - | @VectorMinus _ n ann l r => fun hs => f24 n ann l (F (DTVector n) l _) r (F (DTVector n) r _) - | @MatrixPlus _ n m ann l r => fun hs => f25 n m ann l (F (DTMatrix n m) l _) r (F (DTMatrix n m) r _) - | @MatrixMinus _ n m ann l r => fun hs => - f26 n m ann l (F (DTMatrix n m) l _) r (F (DTMatrix n m) r _) - | @VectorScalMult _ n ann x l => fun hs => f27 n ann x (F DTfloat x _) l (F (DTVector n) l _) - | @MatrixScalMult _ n m ann x l => fun hs => f28 n m ann x (F DTfloat x _) l (F (DTMatrix n m) l _) - | @VectorApply _ n ann v s l => fun hs => f29 n ann v s _ (F DTfloat s _) l (F (DTVector n) l _) - | @MatrixApply _ m n ann v s l => fun hs => - f30 m n ann v s _ (F DTfloat s _) l (F (DTMatrix m n) l _) - | @VLossfun _ n ann v1 v2 s l r => fun hs => - f31 n ann v1 v2 s _ (F DTfloat s _) l (F (DTVector n) l _) r - | @MLossfun _ m n ann v1 v2 s l r => fun hs => - f32 m n ann v1 v2 s _ (F DTfloat s _) l (F (DTMatrix m n) l _) r - end); simpl in hs; intuition. - - exact (proj1 (vforall_forall has_scalar_functions x) hs s). - - rewrite vforall_forall in hs. - specialize (hs s). - rewrite vforall_forall in hs. - specialize (hs s0). - exact hs. - Defined. - - Fixpoint is_df_rec_prop {Ann} {T} - (prop : forall TT:definition_function_types, - (DefinedFunction Ann TT) -> Prop) - (df:DefinedFunction Ann T) {struct df}: Prop - := prop T df /\ - match df with - | Number _ _ => True - | Constant _ _ _ => True - | DVector n _ vec => - vforall (is_df_rec_prop prop) vec - | DMatrix n m _ mat => - vforall (vforall (is_df_rec_prop prop)) mat - | Var _ _ => True - | Plus _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r) - | Minus _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r) - | Times _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r) - | Divide _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r) - | Square _ l => is_df_rec_prop prop l - | Exp _ l => is_df_rec_prop prop l - | Log _ l => is_df_rec_prop prop l - | Abs _ l => is_df_rec_prop prop l - | Sign _ l => is_df_rec_prop prop l - | PSign _ l => is_df_rec_prop prop l - | Max _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r) - | VectorDot n _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | VectorSum _ _ l => is_df_rec_prop prop l - | MatrixSum _ _ _ l => is_df_rec_prop prop l - | VectorElem _ _ vec i => is_df_rec_prop prop vec - | MatrixElem _ _ _ mat i j => is_df_rec_prop prop mat - | MatrixVectorMult _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | MatrixVectorAdd _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | MatrixMult _ _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | VectorPlus _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | VectorMinus _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | MatrixPlus _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | MatrixMinus _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | VectorScalMult _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | MatrixScalMult _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | VectorApply _ _ _ _ l => is_df_rec_prop prop l - | MatrixApply _ _ _ _ _ l => is_df_rec_prop prop l - | VLossfun _ _ _ _ _ l _ => is_df_rec_prop prop l - | MLossfun _ _ _ _ _ _ l _ => is_df_rec_prop prop l - end. - - Fixpoint df_strip_annotations {Ann} {T} - (df:DefinedFunction Ann T) {struct df}: DefinedFunction UnitAnn T - := - match df with - | Number _ x1 => Number tt x1 - | Constant t _ x => Constant tt x - | DVector n _ vec => DVector tt (vmap df_strip_annotations vec) - | DMatrix n m _ mat => DMatrix tt (vmap (vmap df_strip_annotations) mat) - | Var v _ => Var v tt - | Plus _ l r => Plus tt (df_strip_annotations l) (df_strip_annotations r) - | Minus _ l r => Minus tt (df_strip_annotations l) (df_strip_annotations r) - | Times _ l r => Times tt (df_strip_annotations l) (df_strip_annotations r) - | Divide _ l r => Divide tt (df_strip_annotations l) (df_strip_annotations r) - | Square _ l => Square tt (df_strip_annotations l) - | Exp _ l => Exp tt (df_strip_annotations l) - | Log _ l => Log tt (df_strip_annotations l) - | Abs _ l => Abs tt (df_strip_annotations l) - | Sign _ l => Sign tt (df_strip_annotations l) - | PSign _ l => PSign tt (df_strip_annotations l) - | Max _ l r => Max tt (df_strip_annotations l) (df_strip_annotations r) - | VectorDot n _ l r => VectorDot tt (df_strip_annotations l) (df_strip_annotations r) - | VectorSum n _ l => VectorSum tt (df_strip_annotations l) - | MatrixSum m n _ l => MatrixSum tt (df_strip_annotations l) - | VectorElem n _ vec i => VectorElem tt (df_strip_annotations vec) i - | MatrixElem m n _ mat i j => MatrixElem tt (df_strip_annotations mat) i j - | MatrixVectorMult m n _ l r => MatrixVectorMult tt (df_strip_annotations l) (df_strip_annotations r) - | MatrixVectorAdd m n _ l r => MatrixVectorAdd tt (df_strip_annotations l) (df_strip_annotations r) - | MatrixMult m p n _ l r => MatrixMult tt (df_strip_annotations l) (df_strip_annotations r) - | VectorPlus n _ l r => VectorPlus tt (df_strip_annotations l) (df_strip_annotations r) - | VectorMinus n _ l r => VectorMinus tt (df_strip_annotations l) (df_strip_annotations r) - | MatrixPlus m n _ l r => MatrixPlus tt (df_strip_annotations l) (df_strip_annotations r) - | MatrixMinus m n _ l r => MatrixMinus tt (df_strip_annotations l) (df_strip_annotations r) - | VectorScalMult n _ l r => VectorScalMult tt (df_strip_annotations l) (df_strip_annotations r) - | MatrixScalMult m n _ l r => MatrixScalMult tt (df_strip_annotations l) (df_strip_annotations r) - | VectorApply n _ v s l => VectorApply tt v (df_strip_annotations s) (df_strip_annotations l) - | MatrixApply m n _ v s l => MatrixApply tt v (df_strip_annotations s) (df_strip_annotations l) - | VLossfun n _ v1 v2 s l r => VLossfun tt v1 v2 (df_strip_annotations s) (df_strip_annotations l) r - | MLossfun m n _ v1 v2 s l r => MLossfun tt v1 v2 (df_strip_annotations s) (df_strip_annotations l) r - end. - - - Lemma df_strip_annotations_id {T} (df:DefinedFunction UnitAnn T) : df_strip_annotations df = df. - Proof. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case; simpl; trivial - ; destruct ann; trivial; try congruence. - - Case "DVector"%string. - f_equal. - erewrite vmap_ext; [apply vmap_id | ]; intros. - simpl. - destruct H0 as [??]; subst. - eapply H; eauto. - - Case "DMatrix"%string. - f_equal. - erewrite vmap_ext; [apply vmap_id | ]; intros. - simpl. - erewrite vmap_ext; [apply vmap_id | ]; intros. - simpl. - destruct H0 as [??]; subst. - destruct H1 as [??]; subst. - eapply H; eauto. - Qed. - - Definition df_eq_upto_annotations {Ann1 Ann2 T} - (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) : Prop - := df_strip_annotations df1 = df_strip_annotations df2. - - Definition is_df_evalann_correct_top (σ:df_env) {T} (df:DefinedFunction EvalAnn T) - := df_eval σ df = Some (get_annotation df). - - Definition is_df_evalann_correct (σ:df_env) {T} (df:DefinedFunction EvalAnn T) - := is_df_rec_prop (@is_df_evalann_correct_top σ) df. - - Lemma is_df_rec_prop_top {Ann} {T} - {prop : forall TT:definition_function_types, - (DefinedFunction Ann TT) -> Prop} - {df:DefinedFunction Ann T} : - is_df_rec_prop prop df -> - prop _ df. - Proof. - destruct df; simpl; tauto. - Qed. - - Lemma df_eval_tree_correct {T Ann} (σ:df_env) (df:DefinedFunction Ann T) (dfann:DefinedFunction EvalAnn T): - df_eval_tree σ df = Some dfann -> - is_df_evalann_correct σ dfann. - Proof. - unfold is_df_evalann_correct, is_df_evalann_correct_top. - revert dfann. - DefinedFunction_cases (induction df) Case; simpl; intros dfann eqq - ; try solve[case_eq (df_eval_tree σ df1) - ; [intros adf1 a1eqq | intros a1eqq] - ; rewrite a1eqq in eqq - ; [| congruence] - ; (case_eq (df_eval_tree σ df2) - ; [intros adf2 a2eqq | intros a2eqq] - ; rewrite a2eqq in eqq - ; [| congruence] - ; inversion eqq; simpl - ; specialize (IHdf1 _ a1eqq) - ; specialize (IHdf2 _ a2eqq) - ; split; [| tauto] - ; apply is_df_rec_prop_top in IHdf1 - ; apply is_df_rec_prop_top in IHdf2 - ; simpl in IHdf1, IHdf2 - ; rewrite IHdf1, IHdf2 - ; trivial) - - | - case_eq (df_eval_tree σ df) - ; [intros adf aeqq | intros aeqq] - ; rewrite aeqq in eqq - ; [| congruence] - ; inversion eqq; simpl - ; specialize (IHdf _ aeqq) - ; split; [| tauto] - ; apply is_df_rec_prop_top in IHdf - ; simpl in IHdf - ; rewrite IHdf - ; trivial - ]. - - - Case "Number"%string. - inversion eqq; subst. - simpl; tauto. - - Case "Constant"%string. - inversion eqq; subst. - simpl; tauto. - - Case "DVector"%string. - match_option_in eqq. - invcs eqq. - simpl. - specialize (vectoro_to_ovector_forall_some_f eqq0) - ; simpl - ; clear eqq0; intros eqq0. - split. - + apply vectoro_to_ovector_forall_some_b_strong; intros i. - specialize (H _ _ (eqq0 i)). - apply is_df_rec_prop_top in H. - simpl in *. - rewrite vmap_nth; trivial. - + apply vforall_forall; eauto. - - Case "DMatrix"%string. - match_option_in eqq. - invcs eqq. - simpl. - unfold matrixo_to_omatrix in *. - specialize (vectoro_to_ovector_forall_some_f eqq0) - ; simpl - ; clear eqq0; intros eqq0. - split. - + apply vectoro_to_ovector_forall_some_b_strong; intros i. - apply vectoro_to_ovector_forall_some_b_strong; intros j. - specialize (eqq0 i). - specialize (vectoro_to_ovector_forall_some_f eqq0) - ; simpl - ; clear eqq0; intros eqq0. - specialize (eqq0 j). - specialize (H _ _ _ eqq0). - apply is_df_rec_prop_top in H. - simpl in *. - repeat rewrite vmap_nth; trivial. - + apply vforall_forall; intros. - apply vforall_forall; intros. - specialize (eqq0 i). - specialize (vectoro_to_ovector_forall_some_f eqq0) - ; simpl - ; clear eqq0; intros eqq0. - eauto. - - Case "Var"%string. - revert eqq. - case_eq (vartlookup σ v) ; [| congruence]. - intros. - inversion eqq; subst; simpl. - rewrite H; tauto. - - Case "VectorApply"%string. - case_eq (df_eval_tree σ df2) - ; [intros adf2 a2eqq | intros a2eqq] - ; rewrite a2eqq in eqq - ; [| congruence]. - specialize (IHdf2 _ a2eqq). - match_option_in eqq. - invcs eqq. - simpl. - split; trivial. - apply is_df_rec_prop_top in IHdf2. - simpl in IHdf2. - rewrite IHdf2; trivial. - - Case "MatrixApply"%string. - case_eq (df_eval_tree σ df2) - ; [intros adf2 a2eqq | intros a2eqq] - ; rewrite a2eqq in eqq - ; [| congruence]. - specialize (IHdf2 _ a2eqq). - match_option_in eqq. - invcs eqq. - simpl. - split; trivial. - apply is_df_rec_prop_top in IHdf2. - simpl in IHdf2. - rewrite IHdf2; trivial. - - Case "VLossfun"%string. - case_eq (df_eval_tree σ df2) - ; [intros adf2 a2eqq | intros a2eqq] - ; rewrite a2eqq in eqq - ; [| congruence]. - specialize (IHdf2 _ a2eqq). - match_option_in eqq. - invcs eqq. - simpl. - split; trivial. - apply is_df_rec_prop_top in IHdf2. - simpl in IHdf2. - rewrite IHdf2, eqq0; trivial. - - Case "MLossfun"%string. - case_eq (df_eval_tree σ df2) - ; [intros adf2 a2eqq | intros a2eqq] - ; rewrite a2eqq in eqq - ; [| congruence]. - specialize (IHdf2 _ a2eqq). - match_option_in eqq. - invcs eqq. - simpl. - split; trivial. - apply is_df_rec_prop_top in IHdf2. - simpl in IHdf2. - rewrite IHdf2, eqq0; trivial. - Qed. - - - Lemma df_eval_tree_deriv_correct {T} {σ:df_env} {df:DefinedFunction EvalAnn T} : - is_df_evalann_correct σ df -> - forall (xv:var_type), - (* let xv := (v, DTfloat) in *) - df_eval_tree_deriv σ df xv = df_eval_deriv σ df xv. - Proof. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; - intro iscor; destruct iscor; - simpl; intros; trivial; unfold is_df_evalann_correct in * - ; try solve - [ - rewrite IHdf - ; trivial - | - assert (is_df_evalann_correct_top σ df) - ; [ apply is_df_rec_prop_top; trivial | - unfold is_df_evalann_correct_top in H1 - ; rewrite H1 - ; rewrite IHdf - ; trivial - ] - | - rewrite IHdf1; - [ rewrite IHdf2 - ; trivial - ; tauto - | - tauto - ] - | - rewrite IHdf1; - [ assert (is_df_evalann_correct_top σ df2); - [ apply is_df_rec_prop_top; trivial - | unfold is_df_evalann_correct_top in H1; - rewrite H1; trivial] - | tauto] - | - destruct H0; rewrite IHdf1; - [rewrite IHdf2; - [assert (is_df_evalann_correct_top σ df1); - [apply is_df_rec_prop_top; trivial - | assert (is_df_evalann_correct_top σ df2); - [ apply is_df_rec_prop_top; trivial - | - unfold is_df_evalann_correct_top in H2; - unfold is_df_evalann_correct_top in H3; - rewrite H2; rewrite H3; trivial ]] - | - tauto] - | - tauto] - ]. - - Case "DVector"%string. - f_equal. - apply FunctionalExtensionality.functional_extensionality. - intro. - apply H. - destruct H0. - rewrite vforall_forall in H1. - eauto. - - Case "DMatrix"%string. - f_equal. - apply FunctionalExtensionality.functional_extensionality. - intro. - apply FunctionalExtensionality.functional_extensionality. - intros. - apply H. - destruct H0. - rewrite vforall_forall in H1. - specialize (H1 x0). - rewrite vforall_forall in H1. - eauto. - Qed. - - Lemma df_eval_tree_backprop_deriv_correct {T} (σ gradenv:df_env) (df:DefinedFunction EvalAnn T) (grad : definition_function_types_interp T) : - is_df_evalann_correct σ df -> - df_eval_tree_backprop_deriv σ df gradenv grad = df_eval_backprop_deriv σ df gradenv grad. - Proof. - revert gradenv grad. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; - intros; simpl;trivial - ; try solve [ - destruct H; - assert (is_df_evalann_correct σ df); - [ unfold is_df_evalann_correct; trivial - | apply is_df_rec_prop_top in H0; - rewrite IHdf; - [ unfold is_df_evalann_correct_top in H0; - rewrite H0; trivial - | trivial]] - | - destruct H; - apply IHdf; - unfold is_df_evalann_correct; trivial - | - destruct H; destruct H0; rewrite IHdf1; - [ case_eq (df_eval_backprop_deriv σ df1 gradenv grad); [|congruence]; - intros; - rewrite IHdf2; trivial - | unfold is_df_evalann_correct; trivial] - - | - destruct H; - assert (is_df_evalann_correct σ df2); - [ unfold is_df_evalann_correct; trivial - | apply is_df_rec_prop_top in H0; - unfold is_df_evalann_correct_top in H0; - rewrite H0; - match_destr; - rewrite IHdf1; trivial] - | - destruct H; destruct H0; assert (is_df_evalann_correct σ df1); - [ unfold is_df_evalann_correct; trivial - | assert (is_df_evalann_correct σ df2); - [ unfold is_df_evalann_correct; trivial - | rewrite IHdf1; trivial; - apply is_df_rec_prop_top in H0; - apply is_df_rec_prop_top in H1; - unfold is_df_evalann_correct_top in H0; - unfold is_df_evalann_correct_top in H1; - rewrite H0; rewrite H1; - match_destr; - rewrite IHdf2; trivial]] - ]. - - Case "DVector"%string. - destruct H0. - rewrite vforall_forall in H1. - unfold two_vector_env_iter_alt. - f_equal; apply FunctionalExtensionality.functional_extensionality; intros - ; apply FunctionalExtensionality.functional_extensionality ; intros. - apply H; unfold is_df_evalann_correct; apply H1. - - Case "DMatrix"%string. - destruct H0. - rewrite vforall_forall in H1. - unfold two_matrix_env_iter_alt. - f_equal; apply FunctionalExtensionality.functional_extensionality; intros - ; apply FunctionalExtensionality.functional_extensionality; intros. - f_equal; apply FunctionalExtensionality.functional_extensionality; intros - ; apply FunctionalExtensionality.functional_extensionality; intros. - apply H; unfold is_df_evalann_correct. - specialize (H1 x0). - rewrite vforall_forall in H1; apply H1. - - Case "MatrixVectorAdd"%string. - destruct H. - destruct H0. - assert (is_df_evalann_correct σ df1). - unfold is_df_evalann_correct; trivial. - assert (is_df_evalann_correct σ df2). - unfold is_df_evalann_correct; trivial. - rewrite IHdf1; trivial. - match_destr. - assert - (list_env_iter - (fun (i : {m' : nat | (m' < n)%nat}) (env : df_env) => - df_eval_tree_backprop_deriv σ df2 env (transpose grad i)) - (Some d) (bounded_seq0 n) = - list_env_iter - (fun (i : {m' : nat | (m' < n)%nat}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d) - (bounded_seq0 n)). - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite IHdf2; trivial. - rewrite H4; trivial. - Qed. - - Lemma df_eval_ignores_ann {Ann T} {σ:df_env} - (df:DefinedFunction Ann T) : - df_eval σ df = df_eval σ (df_strip_annotations df). - Proof. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; simpl; trivial - ; try solve [ - rewrite IHdf; trivial - | - rewrite IHdf1; - case_eq (df_eval σ (df_strip_annotations df1)); [|congruence]; - intros; rewrite IHdf2; trivial - | - rewrite IHdf1; trivial; - case_eq (df_eval σ (df_strip_annotations df2)); [|congruence]; - intros; f_equal; - apply FunctionalExtensionality.functional_extensionality; intros; - f_equal; rewrite df_strip_annotations_id; trivial - | - rewrite IHdf1; trivial; - case_eq (df_eval σ (df_strip_annotations df2)); [|congruence]; - intros; f_equal; - rewrite df_strip_annotations_id; trivial - ]. - - - Case "DVector"%string. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite H. - rewrite vmap_nth; trivial. - - Case "DMatrix"%string. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H x0). - rewrite H. - rewrite vmap_nth. - rewrite vmap_nth; trivial. - Qed. - - Lemma df_eval_ignores_ann2 {Ann1 Ann2 T} {σ:df_env} - (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) : - df_eq_upto_annotations df1 df2 -> - df_eval σ df1 = df_eval σ df2. - Proof. - assert (df_eval σ df1 = df_eval σ (df_strip_annotations df1)) by apply df_eval_ignores_ann. - assert (df_eval σ df2 = df_eval σ (df_strip_annotations df2)) by apply df_eval_ignores_ann. - congruence. - Qed. - - Lemma df_eval_deriv_ignores_ann {Ann T} {σ:df_env} - (df:DefinedFunction Ann T) : - forall (xv:var_type), - df_eval_deriv σ df xv = df_eval_deriv σ (df_strip_annotations df) xv. - Proof. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; simpl; trivial - ; try solve - [ - intro; rewrite IHdf1; - case_eq (df_eval_deriv σ (df_strip_annotations df1) xv); [|congruence]; - intros; - rewrite IHdf2; trivial - | - intro; rewrite df_eval_ignores_ann; - case_eq (df_eval σ (df_strip_annotations df1)); [|congruence]; - intros; rewrite IHdf1; intros; - case_eq (df_eval_deriv σ (df_strip_annotations df1) xv); [|congruence]; - intros; rewrite df_eval_ignores_ann; - case_eq (df_eval σ (df_strip_annotations df2)); [|congruence]; - intros; rewrite IHdf2; trivial - | - intros; rewrite df_eval_ignores_ann; - case_eq (df_eval σ (df_strip_annotations df)); [|congruence]; - intros; rewrite IHdf; intros; - case_eq (df_eval_deriv σ (df_strip_annotations df) xv); [|congruence]; - trivial - | - intros; rewrite IHdf; trivial - | - intro; rewrite df_eval_ignores_ann; - case_eq (df_eval σ (df_strip_annotations df2)); [|congruence]; - intros; rewrite IHdf1; intros; - rewrite df_strip_annotations_id; trivial - ]. - - - Case "DVector"%string. - intros. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite H. - rewrite vmap_nth; trivial. - - Case "DMatrix"%string. - intros. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H x0). - rewrite H. - rewrite vmap_nth. - rewrite vmap_nth; trivial. - - Case "Max"%string. - intro; rewrite df_eval_ignores_ann. - case_eq (df_eval σ (df_strip_annotations df1)); [|congruence]. - intros. - rewrite df_eval_ignores_ann. - case_eq (df_eval σ (df_strip_annotations df2)); [|congruence]. - intros. - rewrite IHdf1. - rewrite IHdf2; trivial. - Qed. - - Lemma df_eval_deriv_ignores_ann2 {Ann1 Ann2 T} {σ:df_env} - (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) : - forall (xv:var_type), - df_eq_upto_annotations df1 df2 -> - df_eval_deriv σ df1 xv = df_eval_deriv σ df2 xv. - Proof. - intro. - assert (df_eval_deriv σ df1 xv = df_eval_deriv σ (df_strip_annotations df1) xv) by apply df_eval_deriv_ignores_ann. - assert (df_eval_deriv σ df2 xv = df_eval_deriv σ (df_strip_annotations df2) xv) by apply df_eval_deriv_ignores_ann. - congruence. - Qed. - - Lemma is_scalar_function_scalar {Ann} {T} (df:DefinedFunction Ann T) : - is_scalar_function df -> is_scalar_df_type T. - Proof. - induction df; simpl; trivial. - Qed. - - - - Definition definition_function_types_map_base (f:Type->Type) (dft:definition_function_types): Type - := match dft with - | DTfloat => f float - | DTVector n => Vector (f float) n - | DTMatrix m n => Matrix (f float) m n - end. - - Definition definition_function_types_subgradient (dft:definition_function_types) - := definition_function_types_map_base (fun t => list (list t)) dft. - - - Definition df_eval_gradient {T} σ (df:DefinedFunction UnitAnn T) (lv:list var_type) : option (list (definition_function_types_interp T)) - := listo_to_olist (map (df_eval_deriv σ df) lv). - - Definition combine_prod (l1 l2 : list (list float)) : list (list (float * float)) - := let l12 := list_prod l1 l2 - in map (fun '(x,y) => combine x y) l12. -(* - Fixpoint df_eval_subgradient {dft:definition_function_types} (σ:df_env) (df:DefinedFunction dft) (lv:list SubVar) : option (definition_function_types_subgradient dft) - := (match df with - | Number _ => Some ((map (fun _ => 0) lv) :: nil) - | DVector n v => vectoro_to_ovector (vmap (fun x => df_eval_subgradient σ x lv) v) - | DMatrix n m df => matrixo_to_omatrix (vmap (fun x => vmap (fun y => df_eval_subgradient σ y lv) x) df) - | Var x => Some ((map (fun v => if x == v then 1 else 0) lv) :: nil) - | Plus l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (map (map (fun '(x, y) => x+y)) (combine_prod ld rd)) - | _, _ => None - end - | Minus l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (map (map (fun '(x, y) => x-y)) (combine_prod ld rd)) - | _, _ => None - end - | Times l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - Some (map (map (fun '(lp,rp) => lp*re + le*rp)) (combine_prod ld rd)) - | _, _, _, _ => None - end - | Divide l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - Some (map (map (fun '(lp,rp) => (lp*re - le*rp)/(re * re))) (combine_prod ld rd)) - | _, _, _, _ => None - end - | Square e => - match df_eval σ e, df_eval_subgradient σ e lv with - | Some ee, Some ed => Some (map (map (fun pd => 2 * ee * pd)) ed) - | _, _ => None - end - | Exp e => - match df_eval σ e, df_eval_subgradient σ e lv with - | Some ee, Some ed => Some (map (map (fun pd => pd * Fexp ee)) ed) - | _, _ => None - end - | Log e => - match df_eval σ e, df_eval_subgradient σ e lv with - | Some ee, Some ed => Some (map (map (fun pd => (pd / ee))) ed) - | _, _ => None - end - | Abs e => - match df_eval σ e, df_eval_subgradient σ e lv with - | Some ee, Some ed => - if Feq ee 0 then Some (ed ++ (map (map (fun ep => -ep)) ed)) - else Some (map (map (fun ed => (ed * (sign ee)))) ed) - | _, _ => None - end - | Sign e => - match df_eval σ e with - | Some ee => Some ((map (fun _ => 0) lv) :: nil ) - | _ => None - end - | PSign e => - match df_eval σ e with - | Some ee => Some ((map (fun _ => 0) lv) :: nil ) - | _ => None - end - | Max l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - if Feq le re then Some (ld ++ rd) - else if le > re then Some ld - else Some rd - | _, _, _, _ => None - end - | VectorElem n l i => - match (df_eval_subgradient σ l lv) with - | Some l' => Some (l' i) - | _ => None - end - | MatrixElem m n l i j => - match (df_eval_subgradient σ l lv) with - | Some l' => Some (l' i j) - | _ => None - end - | VectorSum n l => - match df_eval_subgradient σ l lv with - | Some l' => - Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp)) - (combine_prod a b)) - ((map (fun _ => 0) lv)::nil) l') - | _ => None - end - | VectorDot n l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp)) - (combine_prod a b)) - ((map (fun _ => 0) lv)::nil) - (fun i => map (map (fun '(lp,rp) => lp*(re i) + (le i)*rp)) - (combine_prod (ld i) (rd i)))) - | _, _, _, _ => None - end - | VectorScalMult n x r => - match df_eval σ x, df_eval_subgradient σ x lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some xe, Some xd, Some re, Some rd => - Some (fun j => map (map (fun '(xp,rp) => xe * rp + xp * (re j))) (combine_prod xd (rd j))) - | _, _, _, _ => None - end - | MatrixScalMult n m x r => - match df_eval σ x, df_eval_subgradient σ x lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some xe, Some xd, Some re, Some rd => - Some (fun i j => map (map (fun '(xp,rp) => xe * rp + xp * (re i j))) (combine_prod xd (rd i j))) - - | _, _, _, _ => None - end - | MatrixVectorMult n m l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - Some (fun i => - (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp)) - (combine_prod a b)) - ((map (fun _ => 0) lv)::nil) - (fun j => map (map (fun '(lp,rp) => lp*(re j) + (le i j)*rp)) - (combine_prod (ld i j) (rd j))))) - | _, _, _, _ => None - end - | MatrixMult n m p l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - Some (fun i k => - (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp)) - (combine_prod a b)) - ((map (fun _ => 0) lv)::nil) - (fun j => map (map (fun '(lp,rp) => lp*(re j k) + (le i j)*rp)) - (combine_prod (ld i j) (rd j k))))) - | _, _, _, _ => None - end - | VectorPlus n l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (fun i => (map (map (fun '(x, y) => x+y)) (combine_prod (ld i) (rd i)))) - | _, _ => None - end - | VectorMinus n l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (fun i => (map (map (fun '(x, y) => x-y)) (combine_prod (ld i) (rd i)))) - | _, _ => None - end - | MatrixPlus n m l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (fun i j => (map (map (fun '(x, y) => x+y)) (combine_prod (ld i j) (rd i j)))) - | _, _ => None - end - | MatrixMinus n m l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (fun i j => (map (map (fun '(x, y) => x-y)) (combine_prod (ld i j) (rd i j)))) - | _, _ => None - end - | VectorApply n x s r => - match df_eval σ r, df_eval_subgradient σ r lv with - | Some re, Some rd => - vectoro_to_ovector - (fun i => match df_eval_subgradient (cons (x, re i) σ) s lv with - | Some sd => - Some (map (map (fun '(x, y) => x*y)) (combine_prod (rd i) sd)) - | _ => None - end) - | _, _ => None - end - | Lossfun n v1 v2 s l r => - match df_eval σ l, df_eval_subgradient σ l lv with - | Some le, Some ld => - match (vectoro_to_ovector - (fun i => match df_eval_subgradient (cons (v1, (le i)) (cons (v2, r i) σ)) s lv with - | Some sd => Some (map (map (fun '(x, y) => x*y)) (combine_prod (ld i) sd)) - | _ => None - end)) with - | Some vv => Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp)) - (combine_prod a b)) - ((map (fun _ => 0) lv)::nil) vv) - | _ => None - end - | _, _ => None - end - end). -*) - End deriv2. - - Definition dft_one (dft:definition_function_types) : definition_function_types_interp dft - := match dft with - | DTfloat => 1 - | DTVector n => fun _ => 1 - | DTMatrix m n => fun _ _ => 1 - end. - - Section scalar_ind. - - Fixpoint is_scalar_function_ind_gen {Ann} - {P:forall {T}, DefinedFunction Ann T->Prop} - (fnumber:forall ann x, P (Number ann x)) - (fconstant:forall (ann:Ann DTfloat) x, P (@Constant _ DTfloat ann x)) - (fvar:forall sv ann, P (@Var _ (sv,DTfloat) ann)) - (fplus:forall a l r, P l -> P r -> P (Plus a l r)) - (fminus:forall a l r, P l -> P r -> P (Minus a l r)) - (ftimes:forall a l r, P l -> P r -> P (Times a l r)) - (fdivide:forall a l r, P l -> P r -> P (Divide a l r)) - (fsquare:forall a e, P e -> P (Square a e)) - (fexp:forall a e, P e -> P (Exp a e)) - (flog:forall a e, P e -> P (Log a e)) - (fabs:forall a e, P e -> P (Abs a e)) - (fsign:forall a e, P e -> P (Sign a e)) - (fpsign:forall a e, P e -> P (PSign a e)) - (fmax:forall a l r, P l -> P r -> P (Max a l r)) - {T} - (df:DefinedFunction Ann T) {struct df} : is_scalar_function df -> P df. - Proof. - induction df; simpl; intros isc; try tauto. - - apply fnumber. - - destruct t; simpl in isc; try tauto. - apply fconstant. - - destruct v. - destruct d; simpl in isc; try tauto. - apply fvar. - - apply fplus. - + apply IHdf1; tauto. - + apply IHdf2; tauto. - - apply fminus. - + apply IHdf1; tauto. - + apply IHdf2; tauto. - - apply ftimes. - + apply IHdf1; tauto. - + apply IHdf2; tauto. - - apply fdivide. - + apply IHdf1; tauto. - + apply IHdf2; tauto. - - apply fsquare. - + apply IHdf; tauto. - - apply fexp. - + apply IHdf; tauto. - - apply flog. - + apply IHdf; tauto. - - apply fabs. - + apply IHdf; tauto. - - apply fsign. - + apply IHdf; tauto. - - apply fpsign. - + apply IHdf; tauto. - - apply fmax. - + apply IHdf1; tauto. - + apply IHdf2; tauto. - Qed. - - Definition is_scalar_function_ind {Ann} - {P:DefinedFunction Ann DTfloat->Prop} - (fnumber:forall ann x, P (Number ann x)) - (fconstant:forall (ann:Ann DTfloat) x, P (@Constant _ DTfloat ann x)) - (fvar:forall sv ann, P (@Var _ (sv,DTfloat) ann)) - (fplus:forall a l r, P l -> P r -> P (Plus a l r)) - (fminus:forall a l r, P l -> P r -> P (Minus a l r)) - (ftimes:forall a l r, P l -> P r -> P (Times a l r)) - (fdivide:forall a l r, P l -> P r -> P (Divide a l r)) - (fsquare:forall a e, P e -> P (Square a e)) - (fexp:forall a e, P e -> P (Exp a e)) - (flog:forall a e, P e -> P (Log a e)) - (fabs:forall a e, P e -> P (Abs a e)) - (fsign:forall a e, P e -> P (Sign a e)) - (fpsign:forall a e, P e -> P (PSign a e)) - (fmax:forall a l r, P l -> P r -> P (Max a l r)) - (df:DefinedFunction Ann DTfloat) : is_scalar_function df -> P df. - Proof. - apply (@is_scalar_function_ind_gen _ (fun t => match t with - | DTfloat => fun df => P df - | _ => fun _ => False - end)); trivial. - Qed. - - Definition vartlookup_eq (l1 l2:df_env) : Prop := forall a, vartlookup l1 a = vartlookup l2 a. - - Global Instance vartlookup_eq_equiv : Equivalence vartlookup_eq. - Proof. - unfold vartlookup_eq. - constructor; red. - - intros; reflexivity. - - intros; eauto. - - intro; etransitivity; eauto. - Qed. - - End scalar_ind. - - Lemma lookup_update (xv : var_type) (gradenv : df_env) - (val : definition_function_types_interp (snd xv)) : - vartlookup (vart_update gradenv xv val) xv = Some val. - Proof. - induction gradenv; simpl. - - destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence]. - refl_simpler; simpl; trivial. - - destruct a; simpl. - case_eq (@equiv_dec var_type _ _ _ xv x); simpl; intros. - + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence]. - refl_simpler; simpl; trivial. - + rewrite H; trivial. - Qed. - - Lemma lookup_update_neq (xv1 xv2 : var_type) (gradenv : df_env) - (val : definition_function_types_interp (snd xv1)) : xv1 <> xv2 -> - vartlookup (vart_update gradenv xv1 val) xv2 = vartlookup gradenv xv2. - Proof. - intros neq. - induction gradenv; simpl. - - destruct (@equiv_dec var_type _ _ _ xv2 xv1); congruence. - - destruct a; simpl. - case_eq (@equiv_dec var_type _ _ _ xv1 x); simpl; intros. - + destruct (@equiv_dec var_type _ _ _ xv2 xv1); [congruence | ]. - destruct (@equiv_dec var_type _ _ _ xv2 x); congruence. - + destruct (@equiv_dec var_type _ _ _ xv2 x); congruence. - Qed. - - Lemma lookup_update2 (xv : var_type) (gradenv : df_env) - (val : definition_function_types_interp (snd xv)) : - vartlookup ((mk_env_entry xv val) :: gradenv) xv = Some val. - Proof. - induction gradenv; simpl. - - destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence]. - refl_simpler; simpl; trivial. - - destruct a; simpl. - case_eq (@equiv_dec var_type _ _ _ xv x); simpl; intros. - + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence]. - refl_simpler; simpl; trivial. - + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence]. - refl_simpler; simpl; trivial. - Qed. - - -Tactic Notation "DefinedFunction_scalar_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string]. - - Lemma df_eval_backprop_deriv_preserves_lookup_not_none {Ann T} {env} {grad gradenv d} {df:DefinedFunction Ann T} : - df_eval_backprop_deriv env df gradenv grad = Some d -> - forall xv, - vartlookup gradenv xv <> None -> - vartlookup d xv <> None. - Proof. - simpl. - revert grad gradenv d. - DefinedFunction_cases (induction df) Case; simpl. - - Case "Number"%string; intros; inversion H; subst; easy. - - Case "Constant"%string; intros; inversion H; subst; easy. - - Case "DVector"%string. - intros grad. - unfold two_vector_env_iter_alt. - induction (bounded_seq0 n). - simpl. - intros. - inversion H0; subst; trivial. - simpl. - intros gradenv d. - case_eq (df_eval_backprop_deriv env (x a) gradenv (grad a)). - intros. - specialize (H a (grad a) gradenv d0). - specialize (IHl d0 d). - apply IHl; trivial. - apply H; trivial. - intros. - assert (list_env_iter - (fun (i : {n' : nat | (n' < n)%nat}) (env0 : df_env) => - df_eval_backprop_deriv env (x i) env0 (grad i)) None l = None) - by apply list_env_iter_none. - intros; rewrite H1 in H3; discriminate. - - Case "DMatrix"%string. - intros grad. - unfold two_matrix_env_iter_alt. - induction (bounded_seq0 n); simpl. - { intros; inversion H0; subst; trivial. } - intros gradenv d eqq. - case_eq ((list_env_iter - (fun (j : {m' : nat | (m' < m)%nat}) (env0 : df_env) => - df_eval_backprop_deriv env (x a j) env0 (grad a j)) (Some gradenv) - (bounded_seq0 m))) - ; [ intros dd ddeqq | intros ddeqq] - ; rewrite ddeqq in eqq - ; simpl in eqq - ; [| destruct l; simpl; discriminate]. - specialize (IHl _ _ eqq). - cut (forall xv : var_type, vartlookup gradenv xv <> None -> vartlookup dd xv <> None) - ; [ eauto | ]. - clear d IHl eqq. - revert gradenv dd ddeqq. - induction (bounded_seq0 m); simpl - ; intros gradenv dd ddeqq - ; simpl in ddeqq. - { inversion ddeqq; subst; trivial. } - case_eq (df_eval_backprop_deriv env (x a a0) gradenv (grad a a0)) - ; [intros dd2 ddeqq2 | intros ddeqq2] - ; rewrite ddeqq2 in ddeqq - ; simpl in ddeqq - ; [| destruct l0; simpl; discriminate]. - eauto. - - Case "Var"%string. - intros. - destruct (vartlookup gradenv v) ; [|congruence]. - intros. - inversion H. - destruct (vart_dec v xv). - + subst; rewrite lookup_update. - discriminate. - + rewrite lookup_update_neq; trivial. - - Case "Plus"%string. - intros grad gradenv. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d). - specialize (IHdf2 grad d d0). - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "Minus"%string. - intros grad gradenv. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d). - specialize (IHdf2 (-grad) d d0). - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "Times"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv (d1 * grad)) ; [|congruence]. - intros d2. - specialize (IHdf1 (d1 * grad) gradenv d2). - specialize (IHdf2 (d0 * grad) d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "Divide"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv (grad / d1)) ; [|congruence]. - intros d2. - specialize (IHdf1 (grad / d1) gradenv d2). - specialize (IHdf2 (- d0 / (d1 * d1) * grad) d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "Square"%string; intros. - destruct (df_eval env df) ; [|congruence]. - specialize (IHdf (2 * d0 * grad) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "Exp"%string; intros. - destruct (df_eval env df) ; [|congruence]. - specialize (IHdf (grad * Fexp d0) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "Log"%string; intros. - destruct (df_eval env df) ; [|congruence]. - specialize (IHdf (grad / d0) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "Abs"%string; intros. - destruct (df_eval env df) ; [|congruence]. - specialize (IHdf (grad * sign d0) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "Sign"%string; intros. - specialize (IHdf 0 gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "PSign"%string; intros. - specialize (IHdf 0 gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "Max"%string; intros. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - destruct (d0 <= d1). - specialize (IHdf2 grad gradenv d). - apply IHdf2. - apply H. - trivial. - specialize (IHdf1 grad gradenv d). - apply IHdf1. - apply H. - trivial. - - Case "VectorDot"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv - (vmap (fun rv : float => rv * grad) d1)); [|congruence]. - intros. - specialize (IHdf1 (vmap (fun rv : float => rv * grad) d1) gradenv d2). - specialize (IHdf2 (vmap (fun lv : float => lv * grad) d0) d2 d). - apply IHdf2. - apply H0. - apply IHdf1. - apply H. - apply H1. - - Case "VectorSum"%string. - intros. - specialize (IHdf (ConstVector n grad) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "MatrixSum"%string. - intros. - specialize (IHdf (ConstMatrix m n grad) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "VectorElem"%string. - intros. - specialize (IHdf (fun k : {n' : nat | (n' < n)%nat} => - if equiv_dec (proj1_sig k) (proj1_sig i) - then grad else 0) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "MatrixElem"%string. - intros. - specialize (IHdf (fun (k1 : {n' : nat | (n' < m)%nat}) - (k2 : {m' : nat | (m' < n)%nat}) => - if equiv_dec (proj1_sig k1) (proj1_sig i) - then if equiv_dec (proj1_sig k2) (proj1_sig j) then grad else 0 - else 0) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "MatrixVectorMult"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv - (fun (i : {n' : nat | (n' < m)%nat}) - (j : {m' : nat | (m' < n)%nat}) => grad i * d1 j)) ; [|congruence]. - intros d2. - specialize (IHdf1 (fun (i : {n' : nat | (n' < m)%nat}) - (j : {m' : nat | (m' < n)%nat}) => grad i * d1 j) gradenv d2). - specialize (IHdf2 (matrix_vector_mult - (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < m)%nat}) => d0 j i) grad) d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "MatrixVectorAdd"%string. - intros grad gradenv d. - case_eq (df_eval_backprop_deriv env df1 gradenv grad); [|congruence]. - intros d0 casedf1. - specialize (IHdf1 _ _ _ casedf1). - clear casedf1. - revert gradenv d d0 IHdf1. - induction (bounded_seq0 n). - + simpl. - intros. - inversion H; subst. - eauto. - + intros gradenv d d0 d0eqq. - simpl. - case_eq (df_eval_backprop_deriv env df2 d0 (transpose grad a)); simpl - ; [intros ? eqq1 | intros eqq1]. - * intros. - { apply (IHl d0 _ d1); trivial. - - eapply IHdf2; eauto. - - eauto. - } - * destruct l; simpl; discriminate. - - Case "MatrixMult"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv - (matrix_mult grad - (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < p)%nat}) => d1 j i))) - ; [|congruence]. - intros d2. - specialize (IHdf1 (matrix_mult grad - (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < p)%nat}) => d1 j i)) - gradenv d2). - specialize (IHdf2 (matrix_mult - (fun (i : {n' : nat | (n' < p)%nat}) - (j : {m' : nat | (m' < m)%nat}) => d0 j i) grad) d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "VectorPlus"%string. - intros grad gradenv d. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d0). - specialize (IHdf2 grad d0 d). - apply IHdf2. - apply H0. - apply IHdf1. - apply H. - apply H1. - - Case "VectorMinus"%string. - intros grad gradenv d. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d0). - specialize (IHdf2 (fun i : {n' : nat | (n' < n)%nat} => - grad i) d0 d). - apply IHdf2. - apply H0. - apply IHdf1. - apply H. - apply H1. - - Case "MatrixPlus"%string. - intros grad gradenv d. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d0). - specialize (IHdf2 grad d0 d). - apply IHdf2. - apply H0. - apply IHdf1. - apply H. - apply H1. - - Case "MatrixMinus"%string. - intros grad gradenv d. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d0). - specialize (IHdf2 (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < m)%nat}) => - grad i j) - d0 d). - apply IHdf2. - apply H0. - apply IHdf1. - apply H. - apply H1. - - Case "VectorScalMult"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv - (vsum (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad j))) - ; [|congruence]. - intros d2. - specialize (IHdf1 (vsum (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad j)) - gradenv d2). - specialize (IHdf2 (fun j : {n' : nat | (n' < n)%nat} => d0 * grad j) - d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "MatrixScalMult"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv - (msum - (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < m)%nat}) => - d1 i j * grad i j))) - ; [|congruence]. - intros d2. - specialize (IHdf1 (msum - (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < m)%nat}) => - d1 i j * grad i j)) - gradenv d2). - specialize (IHdf2 (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < m)%nat}) => grad i j * d0) - d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "VectorApply"%string. - intros grad gradenv d. - destruct (df_eval env df2) ; [|congruence]. - simpl in *. - match_destr; simpl; eauto. - - Case "MatrixApply"%string. - intros grad gradenv d. - destruct (df_eval env df2) ; [|congruence]. - match_destr; simpl; eauto. - - Case "VLossfun"%string. - intros grad gradenv d. - destruct (df_eval env df2) ; [|congruence]. - match_destr; simpl; eauto. - - Case "MLossfun"%string. - intros grad gradenv d. - destruct (df_eval env df2) ; [|congruence]. - match_destr; simpl; eauto. - Qed. - - Definition df_eval_deriv_gen_top {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) : - option (lifted_type (definition_function_types_interp T) (snd v)) := - match (snd v) as vt return option (lifted_type (definition_function_types_interp T) vt) with - | DTfloat => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTfloat) 1)::nil) - | DTVector n => - vectoro_to_ovector - (fun i => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTVector n) (UnitVector n i))::nil)) - | DTMatrix n m => - matrixo_to_omatrix - (fun i j => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTMatrix n m) (UnitMatrix n m i j))::nil)) - end. - - Program Definition subvar (x : var_type) (grad_env:df_env) := - (match snd x as y return snd x = y -> - definition_function_types_interp y -> - definition_function_types_interp y with - | DTfloat => fun pf grad => match vartlookup grad_env x with - | Some val => ((coerce _ val):float) - grad - | _ => Fopp grad - end - | DTVector n => fun pf grad => match vartlookup grad_env x with - | Some val => fun i => (((coerce _ val):Vector float n) i) - (grad i) - | _ => vmap Fopp grad - end - | DTMatrix m n => fun pf grad => match vartlookup grad_env x with - | Some val => fun i j => (((coerce _ val):Matrix float m n) i j) - (grad i j) - | _ => vmap (vmap Fopp) grad - end - end) (eq_refl _). - Next Obligation. - rewrite pf; reflexivity. - Defined. - Next Obligation. - rewrite pf; reflexivity. - Defined. - Next Obligation. - rewrite pf; reflexivity. - Defined. - - Definition df_eval_backprop_delta {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) (grad_env:df_env) (grad: definition_function_types_interp T) : - option (definition_function_types_interp (snd v)) := - match vartlookup grad_env v with - | Some old => - lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env grad) - | None => None - end. - - Program Definition df_eval_backward_gen_top {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) (grad_env:df_env) : - option (lifted_type (definition_function_types_interp (snd v)) T) := - match vartlookup grad_env v with - | Some old => - (match T as vt return T = vt -> option (lifted_type (definition_function_types_interp (snd v)) vt) with - | DTfloat => fun pf => (lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ 1))) - | DTVector n => fun pf => - vectoro_to_ovector - (fun i => lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ (UnitVector n i)))) - | DTMatrix m n => fun pf => - matrixo_to_omatrix - (fun i j => lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ (UnitMatrix m n i j)))) - end) (eq_refl _) - | None => None - end. - - Definition transpose_lifted_type {T1 T2} : - lifted_type (definition_function_types_interp T1) T2 -> - lifted_type (definition_function_types_interp T2) T1 - := match T1, T2 with - | DTfloat, _ => fun inp => inp - | _, DTfloat => fun inp => inp - | DTVector n1, DTVector n2 => fun inp => fun i j => inp j i - | DTMatrix m1 n1, DTMatrix m2 n2 => fun inp => fun i j p q => inp p q i j - | DTVector n1, DTMatrix m2 n2 => fun inp => fun i p q => inp p q i - | DTMatrix m1 n1, DTVector n2 => fun inp => fun i j p => inp p i j - end. - Section deriv_deriv. - End deriv_deriv. - - Section max_derived. - Definition MaxDerived (a b : DefinedFunction UnitAnn DTfloat) := - Divide tt (Plus tt (Plus tt (Abs tt (Minus tt b a)) b) a) (Number tt 2). - - Delimit Scope df_scope with df. - - Notation "x + y" := (Plus x y) (only printing) : df_scope. - Notation "x - y" := (Minus x y) (only printing) : df_scope. - Notation "x / y" := (Divide x y) (only printing) : df_scope. - Notation "x * y" := (Times x y) (only printing) : df_scope. - Notation "x" := (Number x) (only printing, at level 0) : df_scope. - Notation "x" := (Var x) (only printing, at level 0) : df_scope. - Notation "'|' x '|'" := (Abs x) (only printing, at level 0) : df_scope. - - End max_derived. - - Section fv. - - Fixpoint df_free_variables {Ann} {T} (f : DefinedFunction Ann T) : list var_type - := match f with - | Number _ x => nil - | DVector n _ x => vlconcat_map df_free_variables x - | Constant t _ x => nil - | DMatrix n m _ x => vlconcat_map (fun a => vlconcat_map df_free_variables a) x - | Var v _ => v::nil - | Plus _ l r => (df_free_variables l) ++ (df_free_variables r) - | Minus _ l r => (df_free_variables l) ++ (df_free_variables r) - | Times _ l r => (df_free_variables l) ++ (df_free_variables r) - | Divide _ l r => (df_free_variables l) ++ (df_free_variables r) - | Max _ l r => (df_free_variables l) ++ (df_free_variables r) - | Abs _ e => df_free_variables e - | Sign _ e => df_free_variables e - | PSign _ e => df_free_variables e - | Log _ e => df_free_variables e - | Square _ e => df_free_variables e - | Exp _ e => df_free_variables e - | VectorElem n _ l i => df_free_variables l - | MatrixElem m n _ l i j => df_free_variables l - | VectorDot n _ l r => (df_free_variables l) ++ (df_free_variables r) - | VectorSum n _ l => df_free_variables l - | MatrixSum n m _ l => df_free_variables l - | VectorScalMult n _ x r => (df_free_variables x) ++ (df_free_variables r) - | MatrixScalMult n m _ x r => (df_free_variables x) ++ (df_free_variables r) - | MatrixVectorMult n m _ l r => (df_free_variables l) ++ (df_free_variables r) - | MatrixVectorAdd n m _ l r => (df_free_variables l) ++ (df_free_variables r) - | MatrixMult n m p _ l r => (df_free_variables l) ++ (df_free_variables r) - | VectorPlus n _ l r => (df_free_variables l) ++ (df_free_variables r) - | VectorMinus n _ l r => (df_free_variables l) ++ (df_free_variables r) - | MatrixPlus n m _ l r => (df_free_variables l) ++ (df_free_variables r) - | MatrixMinus n m _ l r => (df_free_variables l) ++ (df_free_variables r) - | VectorApply n _ x s l => (remove_all (x,DTfloat) (df_free_variables s)) - ++ (df_free_variables l) - | MatrixApply n m _ x s l => (remove_all (x,DTfloat) (df_free_variables s)) - ++ (df_free_variables l) - | VLossfun n _ v1 v2 s l r => (remove_all (v1,DTfloat) (remove_all (v2,DTfloat) (df_free_variables s))) - ++ (df_free_variables l) - | MLossfun n m _ v1 v2 s l r => (remove_all (v1,DTfloat) (remove_all (v2,DTfloat) (df_free_variables s))) - ++ (df_free_variables l) - end. - - Definition df_closed {Ann} {T} (f: DefinedFunction Ann T) : Prop - := match df_free_variables f with - | nil => True - | _ => False - end. - - Lemma df_closed_nil {T} (f: DefinedFunction UnitAnn T) : df_closed f -> df_free_variables f = nil. - Proof. - unfold df_closed. - destruct (df_free_variables f); tauto. - Qed. - - Definition df_closed_over {Ann} {T} (f : DefinedFunction Ann T) (vl : list var_type) : Prop - := incl (df_free_variables f) vl. - - Fixpoint fully_closed_over {Ann} {T} (df : DefinedFunction Ann T) (vl : list var_type) : Prop - := - match df with - | Number _ x => True - | DVector n _ x => vforall (fun f => fully_closed_over f vl) x - | Constant t _ x => True - | DMatrix n m _ x => vforall (fun row => - (vforall (fun f => fully_closed_over f vl) row)) x - | Var v _ => In v vl - | Plus _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | Minus _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | Times _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | Divide _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | Max _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | Abs _ e => fully_closed_over e vl - | Sign _ e => fully_closed_over e vl - | PSign _ e => fully_closed_over e vl - | Log _ e => fully_closed_over e vl - | Square _ e => fully_closed_over e vl - | Exp _ e => fully_closed_over e vl - | VectorElem n _ l i => fully_closed_over l vl - | MatrixElem m n _ l i j => fully_closed_over l vl - | VectorDot n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | VectorSum n _ l => fully_closed_over l vl - | MatrixSum n m _ l => fully_closed_over l vl - | VectorScalMult n _ x r => (fully_closed_over x vl) /\ (fully_closed_over r vl) - | MatrixScalMult n m _ x r => (fully_closed_over x vl) /\ (fully_closed_over r vl) - | MatrixVectorMult n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | MatrixVectorAdd n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | MatrixMult n m p _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | VectorPlus n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | VectorMinus n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | MatrixPlus n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | MatrixMinus n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | VectorApply n _ x s l => (fully_closed_over s ((x,DTfloat)::nil)) /\ - (fully_closed_over l vl) - | MatrixApply n m _ x s l => (fully_closed_over s ((x,DTfloat)::nil)) /\ - (fully_closed_over l vl) - | VLossfun n _ v1 v2 s l r => (fully_closed_over s ((v1,DTfloat)::(v2,DTfloat)::nil)) - /\ (fully_closed_over l vl) - | MLossfun n m _ v1 v2 s l r => (fully_closed_over s ((v1,DTfloat)::(v2,DTfloat)::nil)) - /\ (fully_closed_over l vl) - end. - - Definition In_compat_map (f : list var_type -> list var_type) : Prop := - forall (v : var_type) (vl : list var_type), - In v vl -> In v (f vl). - - Definition map_tl (f : list var_type -> list var_type) (vl : list var_type) := - match vl with - | a :: vl1 => a :: f vl1 - | _ => f vl - end. - - Lemma In_compat_map_tl (f : list var_type -> list var_type) : - In_compat_map f -> In_compat_map (map_tl f). - Proof. - unfold In_compat_map; intros. - destruct vl. - + now simpl. - + simpl in *. - destruct H0. - * now left. - * right; now apply H. - Qed. - - Lemma fully_closed_over_map {T} (df : DefinedFunction UnitAnn T) (vl : list var_type) (f : list var_type -> list var_type) : - In_compat_map f -> fully_closed_over df vl -> fully_closed_over df (f vl). - Proof. - revert f; revert vl. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros; try solve [ - trivial - | - apply IHdf; trivial - | - split; destruct H0; - [apply IHdf1; trivial - | apply IHdf2; trivial] - ]. - - Case "DVector"%string. - apply vforall_forall; intros. - apply H; trivial. - now rewrite vforall_forall in H1. - - Case "DMatrix"%string. - apply vforall_forall; intros. - apply vforall_forall; intros. - apply H; trivial. - rewrite vforall_forall in H1. - specialize (H1 i). - now rewrite vforall_forall in H1. - - now apply H. - - Case "VectorApply"%string. - split; destruct H0; trivial. - now apply IHdf2. - - Case "MatrixApply"%string. - split; destruct H0; trivial. - now apply IHdf2. - - Case "VLossfun"%string. - split; destruct H0; trivial. - now apply IHdf2. - - Case "MLossfun"%string. - split; destruct H0; trivial. - now apply IHdf2. - Qed. - - (* - Lemma closed_is_fully_closed {Ann} {T} (df : DefinedFunction Ann T) (vl : list var_type) : - df_closed_over df vl <-> fully_closed_over df vl. -*) - - -(* - Lemma df_subst_nfree {T} (e: DefinedFunction T) (v:SubVar) (e':DefinedFunction DTfloat) : - ~ In v (df_free_variables e) -> - df_subst e v e' = e. - Proof. - DefinedFunction_cases (induction e) Case; simpl; trivial; intros nin - ; try solve [try rewrite in_app_iff in nin - ; intuition congruence]. - - Case "DVector"%string. - f_equal. - apply functional_extensionality. - intros x0. - apply H. - intros inn. - apply nin. - unfold vlconcat_map, vlconcat. - apply concat_In. - exists ((df_free_variables (x x0))). - split; trivial. - apply vector_to_list_In. - - - Case "DMatrix"%string. - - - Case "Var"%string. - destruct (var_dec v0 v); intuition. - Qed. - - Lemma df_eval_complete' {T} (σ:df_env) (f:DefinedFunction T) : - incl (df_free_variables f) (domain σ) -> {v | df_eval σ f = Some v}. - Proof. - induction f; simpl; intros inc - ; try solve [rewrite <- incl_app_iff in inc - ; intuition - ; destruct X as [v1 ev1] - ; destruct X0 as [v2 ev2] - ; rewrite ev1; rewrite ev2 - ; eauto - | intuition - ; destruct X as [v1 ev1] - ; rewrite ev1 - ; eauto]. - - eauto. - - apply in_dom_lookup_strong. - specialize (inc v); simpl in *. - intuition. - Qed. - - (* This version has better computational properties *) - Lemma df_eval_complete (σ:df_env) (f:DefinedFunction) : - incl (df_free_variables f) (domain σ) -> {v | df_eval σ f = Some v}. - Proof. - case_eq (df_eval σ f); simpl. - - intros r ?? ; exists r; eauto. - - intros ? inc. - destruct (df_eval_complete' _ _ inc); congruence. - Defined. - - Lemma df_eval_none (σ:df_env) (f:DefinedFunction) : - df_eval σ f = None -> - {v | In v (df_free_variables f) /\ ~ In v (domain σ)}. - Proof. - intros. - destruct (incl_dec (df_free_variables f) (domain σ)). - - destruct (df_eval_complete _ _ i); congruence. - - apply (nincl_exists) in n; trivial. - Qed. - - (* Either we can evaluate df or we are missing a variable definition. - Note that this theorem may fail to hold if we change the definition of - division to make it partial. - *) - Lemma df_eval_compute (σ:df_env) (f:DefinedFunction) : - {v | df_eval σ f = Some v} + {x | In x (df_free_variables f) /\ ~ In x (domain σ)}. - Proof. - case_eq (df_eval σ f); simpl. - - eauto. - - intros H; apply df_eval_none in H; eauto. - Defined. - - Lemma df_eval_closed (f:DefinedFunction) : - df_closed f -> {v | df_eval nil f = Some v}. - Proof. - intros c. - apply (df_eval_complete nil f). - rewrite df_closed_nil by trivial. - simpl; reflexivity. - Defined. - - Lemma df_eval_lookup_on (σ₁ σ₂:df_env) (f:DefinedFunction) : - lookup_equiv_on (df_free_variables f) σ₁ σ₂ -> - df_eval σ₁ f = df_eval σ₂ f. - Proof. - intros lookeq. - induction f; simpl in *; trivial - ; try solve [apply lookup_equiv_on_dom_app in lookeq; intuition - ; rewrite H1, H2; trivial - | rewrite IHf; trivial]. - - apply lookeq; simpl; tauto. - Qed. -*) - End fv. - - Section apply. - - Fixpoint df_apply {T} (e: DefinedFunction UnitAnn T) - (args: forall (v:var_type), DefinedFunction UnitAnn (snd v)) : DefinedFunction UnitAnn T := - match e with - | Number _ x => Number tt x - | Constant t _ x => Constant tt x - | DVector n _ df => DVector tt (fun x => df_apply (df x) args) - | DMatrix n m _ df => DMatrix tt (fun i j => df_apply (df i j) args) - | Var v _ => args v - | Plus _ l r => Plus tt (df_apply l args) (df_apply r args) - | Times _ l r => Times tt (df_apply l args) (df_apply r args) - | Minus _ l r => Minus tt (df_apply l args) (df_apply r args) - | Divide _ l r => Divide tt (df_apply l args) (df_apply r args) - | Square _ e => Square tt (df_apply e args) - | Exp _ e => Exp tt (df_apply e args) - | Log _ e => Log tt (df_apply e args) - | Abs _ e => Abs tt (df_apply e args) - | Sign _ e => Sign tt (df_apply e args) - | PSign _ e => PSign tt (df_apply e args) - | Max _ l r => Max tt (df_apply l args) (df_apply r args) - | VectorElem n _ l i => VectorElem tt (df_apply l args) i - | MatrixElem m n _ l i j => MatrixElem tt (df_apply l args) i j - | VectorDot n _ l r => VectorDot tt (df_apply l args) (df_apply r args) - | VectorSum n _ l => VectorSum tt (df_apply l args) - | MatrixSum n m _ l => MatrixSum tt (df_apply l args) - | VectorScalMult n _ x r => VectorScalMult tt (df_apply x args) (df_apply r args) - | MatrixScalMult n m _ x r => MatrixScalMult tt (df_apply x args) (df_apply r args) - | MatrixVectorMult n m _ l r => MatrixVectorMult tt (df_apply l args) (df_apply r args) - | MatrixVectorAdd n m _ l r => MatrixVectorAdd tt (df_apply l args) (df_apply r args) - | MatrixMult n m p _ l r => MatrixMult tt (df_apply l args) (df_apply r args) - | VectorPlus n _ l r => VectorPlus tt (df_apply l args) (df_apply r args) - | VectorMinus n _ l r => VectorMinus tt (df_apply l args) (df_apply r args) - | MatrixPlus n m _ l r => MatrixPlus tt (df_apply l args) (df_apply r args) - | MatrixMinus n m _ l r => MatrixMinus tt (df_apply l args) (df_apply r args) - | VectorApply n _ x s l => VectorApply tt x (df_apply s args) (df_apply l args) - | MatrixApply n m _ x s l => MatrixApply tt x (df_apply s args) (df_apply l args) - | VLossfun n _ v1 v2 s l r => VLossfun tt v1 v2 (df_apply s args) (df_apply l args) r - | MLossfun n m _ v1 v2 s l r => MLossfun tt v1 v2 (df_apply s args) (df_apply l args) r - end. - - End apply. - -End DefinedFunctions. - -Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "DVector"%string - | Case_aux c "DMatrix"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string - | Case_aux c "VectorDot"%string - | Case_aux c "VectorSum"%string - | Case_aux c "MatrixSum"%string - | Case_aux c "VectorElem"%string - | Case_aux c "MatrixElem"%string - | Case_aux c "MatrixVectorMult"%string - | Case_aux c "MatrixVectorAdd"%string - | Case_aux c "MatrixMult"%string - | Case_aux c "VectorPlus"%string - | Case_aux c "VectorMinus"%string - | Case_aux c "MatrixPlus"%string - | Case_aux c "MatrixMinus"%string - | Case_aux c "VectorScalMult"%string - | Case_aux c "MatrixScalMult"%string - | Case_aux c "VectorApply"%string - | Case_aux c "MatrixApply"%string - | Case_aux c "VLossfun"%string - | Case_aux c "MLossfun"%string]. - -Ltac refl_simpler := - repeat - match goal with - | [H: @eq var_type _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H) - | [H: @equiv var_type _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H) - | [H: @eq definition_function_types _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H) - | [H: @equiv definition_function_types _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H) - end. - -Section real_pfs. - - Local Existing Instance floatish_R. - Import Reals. - Import List. - - Lemma MaxDerivedMax_eq (a b : DefinedFunction UnitAnn DTfloat) : - forall σ, df_eval σ (Max tt a b) = df_eval σ (MaxDerived a b). - Proof. - simpl; intros σ. - destruct (df_eval σ a); destruct (df_eval σ b); trivial. - f_equal. - autorewrite with Rarith in *. - destruct (Rle_dec d d0). - - rewrite Rmax_right by trivial. - rewrite Rabs_pos_eq by lra. - lra. - - rewrite Rmax_left by lra. - rewrite Rabs_minus_sym. - rewrite Rabs_pos_eq by lra. - lra. - Qed. - -(* Lemma coerce_dec_id {A} (dec:forall x y:A, {x=y}+{x<>y}) (x:A) (pf:x=x) : coerce pf x = x. - Proof. - unfold coerce. - replace pf with (eq_refl A); trivial. - apply UIP_dec. - apply dec. - generalize (@UIP_dec A dec pf). - Lemma var_type_UIP_refl {x:var_type} (e:x=x) : e = eq_refl x. - Proof. - apply (UIP_dec vart_dec x x pf). - Qed. - - unfold coerce. - destruct pf. - destruct pf. - exact a. - Defined. -*) - -Tactic Notation "DefinedFunction_scalar_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string]. - - - Lemma backpropeq_gen (x : SubVar) (env gradenv : df_env) (dfexpr : DefinedFunction UnitAnn DTfloat) (grad : float) : - let xvar := (x, DTfloat) in - is_scalar_function dfexpr -> - vartlookup gradenv (x,DTfloat) <> None -> - match df_eval_deriv env dfexpr xvar, - backprop_lookup (Some gradenv) xvar, - backprop_lookup (df_eval_backprop_deriv env dfexpr gradenv grad) xvar - with - | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1 - | None, _, None => True - | _, _, _ => False - end. - Proof. - simpl. - intros is_scalar. - generalize is_scalar. - revert grad gradenv. - pattern dfexpr. - revert dfexpr is_scalar. - DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl. - - Case "Number"%string. - intros _ _ grad gradenv _ xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Constant"%string. - intros _ _ grad gradenv _ xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Var"%string. - intros sv _ grad gradenv _ xinn. - case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]. - destruct (var_dec x sv); simpl. - + subst. - rewrite H; simpl. - rewrite lookup_update. - destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence]. - unfold addvar; simpl. - rewrite H. - lra. - + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ]. - case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros. - * rewrite lookup_update_neq by congruence. - rewrite H. - lra. - * rewrite H. - lra. - - Case "Plus"%string. - intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn. - case_eq (df_eval_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv isc1 xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr grad ge' isc2). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_backprop_deriv env r ge' grad) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr; trivial. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv isc1 xinn). - case_eq (df_eval_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Minus"%string. - intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn. - case_eq (df_eval_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv isc1 xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- grad)%R ge' isc2). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr; trivial. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv isc1 xinn). - case_eq (df_eval_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Times"%string. - intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn. - case_eq (df_eval env l); - [ intros le eqle | intros eqle]; simpl; trivial. - case_eq (df_eval_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval env r); - [ intros re eqre | intros eqre] - ; simpl; trivial. - case_eq (df_eval_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (re * grad)%R gradenv isc1 xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env l gradenv (re * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (le * grad)%R ge' isc2). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (le * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_backprop_deriv env l gradenv (re * grad)%R ); simpl; trivial; intros. - apply IHr; trivial. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + case_eq (df_eval env r); - [ intros re eqre | intros eqre] - ; simpl; trivial. - specialize (IHl (re * grad)%R gradenv isc1 xinn). - case_eq (df_eval_backprop_deriv env l gradenv (re * grad)%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Divide"%string. - intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn. - case_eq (df_eval env l); - [ intros le eqle | intros eqle]; simpl; trivial. - case_eq (df_eval_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval env r); - [ intros re eqre | intros eqre] - ; simpl; trivial. - case_eq (df_eval_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (grad / re)%R gradenv isc1 xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env l gradenv (grad / re)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- le / (re * re) * grad)%R ge' isc2). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (- le / (re * re) * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_backprop_deriv env l gradenv (grad / re)%R ); simpl; trivial; intros. - apply IHr; trivial. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + case_eq (df_eval env r); - [ intros re eqre | intros eqre] - ; simpl; trivial. - specialize (IHl (grad / re)%R gradenv isc1 xinn). - case_eq (df_eval_backprop_deriv env l gradenv (grad / re)%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Square"%string. - intros _ e IHe grad gradenv isc xinn. - case_eq (df_eval env e); - [ intros le eqee | intros eqee]; simpl; trivial. - - specialize (IHe (2 * le * grad)%R gradenv isc xinn). - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env e gradenv (2 * le * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Exp"%string. - intros _ e IHe grad gradenv isc xinn. - case_eq (df_eval env e); - [ intros le eqee | intros eqee]; simpl; trivial. - - specialize (IHe (grad * exp le)%R gradenv isc xinn). - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env e gradenv (grad * exp le)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Log"%string. - intros _ e IHe grad gradenv isc xinn. - case_eq (df_eval env e); - [ intros le eqee | intros eqee]; simpl; trivial. - - specialize (IHe (grad / le)%R gradenv isc xinn). - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env e gradenv (grad / le)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Abs"%string. - intros _ e IHe grad gradenv isc xinn. - case_eq (df_eval env e); - [ intros le eqee | intros eqee]; simpl; trivial. - - specialize (IHe (grad * (sign le))%R gradenv isc xinn). - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env e gradenv (grad * (sign le))%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Sign"%string. - intros _ e IHe grad gradenv isc xinn. - specialize (IHe 0%R gradenv isc xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - lra. - - Case "PSign"%string. - intros _ e IHe grad gradenv isc xinn. - specialize (IHe 0%R gradenv isc xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - lra. - - Case "Max"%string. - intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn. - specialize (IHl grad gradenv isc1 xinn). - specialize (IHr grad gradenv isc2 xinn). - - case_eq (df_eval env l); simpl; trivial - ; intros eld eqeld. - case_eq (df_eval env r ); simpl; intros; trivial. - destruct (Rle_dec eld d); simpl. - + destruct (df_eval_deriv env r (x, DTfloat)); simpl; trivial. - + destruct (df_eval_deriv env l (x, DTfloat)); simpl; trivial. - Qed. - - (* - - Lemma tree_backpropeq_gen (x : SubVar) (env gradenv : df_env) - (dfexpr : DefinedFunction EvalAnn DTfloat) (grad : float) : - let xvar := (x, DTfloat) in - is_scalar_function dfexpr -> - vartlookup gradenv (x,DTfloat) <> None -> - match df_eval_tree_deriv env dfexpr xvar, - backprop_lookup (Some gradenv) xvar, - backprop_lookup (df_eval_tree_backprop_deriv env dfexpr gradenv grad) xvar - with - | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1 - | None, _, None => True - | _, _, _ => False - end. - Proof. - simpl. - intros is_scalar. - revert grad gradenv. - pattern dfexpr. - revert dfexpr is_scalar. - DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl. - - - Case "Number"%string. - intros _ _ grad gradenv xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Constant"%string. - intros _ _ grad gradenv xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Var"%string. - intros sv _ grad gradenv xinn. - case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]. - destruct (var_dec x sv); simpl. - + subst. - rewrite H; simpl. - rewrite lookup_update. - destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence]. - unfold addvar; simpl. - rewrite H. - lra. - + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ]. - case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros. - * rewrite lookup_update_neq by congruence. - rewrite H. - lra. - * rewrite H. - lra. - - Case "Plus"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr grad ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' grad) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Minus"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Times"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (get_annotation r * grad)%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (get_annotation l * grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (get_annotation l * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl (get_annotation r * grad)%R gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Divide"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (grad / get_annotation r)%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl (grad / get_annotation r)%R gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r )%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Square"%string. - intros _ e IHe grad gradenv xinn. - - specialize (IHe (2 * (get_annotation e) * grad)%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (2 * (get_annotation e) * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Exp"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad * exp (get_annotation e))%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * exp (get_annotation e))%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - - Case "Log"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad / get_annotation e)%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad / get_annotation e)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - - Case "Abs"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad * (sign (get_annotation e)))%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * (sign (get_annotation e)))%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Sign"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe 0%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - replace (de * 0)%R with (0)%R in IHe by lra. - replace (0 * grad)%R with (0)%R by lra. - apply IHe. - - Case "PSign"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe 0%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - replace (de * 0)%R with (0)%R in IHe by lra. - replace (0 * grad)%R with (0)%R by lra. - apply IHe. - - Case "Max"%string. - intros _ l r IHl IHr grad gradenv xinn. - specialize (IHl grad gradenv xinn). - specialize (IHr grad gradenv xinn). - destruct (Rle_dec (get_annotation l) (get_annotation r)); simpl. - destruct (df_eval_tree_deriv env r (x, DTfloat)); simpl; trivial. - destruct (df_eval_tree_deriv env l (x, DTfloat)); simpl; trivial. - Qed. - *) - - Lemma eval_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T) : - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> df_eval σ df <> None. - Proof. - revert σ. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros; - try solve - [congruence - | - destruct H; simpl in *; - specialize (IHdf1 σ); specialize (IHdf2 σ); - match_option; [|tauto]; - cut_to IHdf2; - [ match_option; tauto | easy] - | - specialize (IHdf σ); simpl in IHdf; - cut_to IHdf; - [ match_option; tauto | easy] - ]. - - Case "DVector"%string. - apply vectoro_to_ovector_not_none; intro. - specialize (H i σ); simpl in H; apply H. - rewrite vforall_forall in H0. - now specialize (H0 i). - - Case "DMatrix"%string. - unfold matrixo_to_omatrix. - apply vectoro_to_ovector_not_none; intro. - apply vectoro_to_ovector_not_none; intro. - specialize (H i i0 σ); simpl in H; apply H. - rewrite vforall_forall in H0; specialize (H0 i). - rewrite vforall_forall in H0; now specialize (H0 i0). - - Case "Var"%string. - induction σ. - + simpl in H; tauto. - + simpl in *. - match_case; intros. - destruct H. - * congruence. - * now apply IHσ. - - Case "VectorApply"%string. - destruct H; simpl in *. - specialize (IHdf2 σ). - cut_to IHdf2; trivial. - match_option; [|tauto]. - apply vectoro_to_ovector_not_none; intro. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil)). - now apply IHdf1. - - Case "MatrixApply"%string. - destruct H; simpl in *. - specialize (IHdf2 σ). - cut_to IHdf2; trivial. - match_option; [|tauto]. - unfold matrixo_to_omatrix. - apply vectoro_to_ovector_not_none; intro. - apply vectoro_to_ovector_not_none; intro. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil)). - now apply IHdf1. - - Case "VLossfun"%string. - destruct H; simpl in *. - specialize (IHdf2 σ). - cut_to IHdf2; trivial. - match_option; [|tauto]. - match_option. - apply vectoro_to_ovector_not_none in eqq0. - + tauto. - + intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil)). - now apply IHdf1. - - Case "MLossfun"%string. - destruct H; simpl in *. - specialize (IHdf2 σ). - cut_to IHdf2; trivial. - match_option; [|tauto]. - unfold matrixo_to_omatrix. - match_option. - apply vectoro_to_ovector_not_none in eqq0. - + tauto. - + intros. - apply vectoro_to_ovector_not_none; intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil)). - now apply IHdf1. - Qed. - - Lemma eval_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T) : - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - {d:definition_function_types_interp T | df_eval σ df = Some d}. - Proof. - intros. - case_eq (df_eval σ df); intros. - - now exists d. - - generalize (eval_fully_closed_not_none σ df). - intros; simpl in *. - cut_to H1; tauto. - Qed. - - Lemma closed_over_cons {T} (df:DefinedFunction UnitAnn T) (v:var_type) (vl : list var_type): - df_closed_over df vl -> df_closed_over df (v::vl). - Proof. - unfold df_closed_over. - intros. - apply incl_tl. - apply H. - Qed. - - Lemma fully_closed_over_cons {T} (df:DefinedFunction UnitAnn T) (v:var_type) - (vl : list var_type): - fully_closed_over df vl -> fully_closed_over df (v::vl). - Proof. - intros. - apply (fully_closed_over_map df vl (fun vl1 => cons v vl1)); trivial. - unfold In_compat_map. - intros. - now apply in_cons. - Qed. - - Lemma fully_closed_over_exchange_vars {T} (df:DefinedFunction UnitAnn T) (v1 v:var_type) - (vl : list var_type): - fully_closed_over df (v1 :: v :: vl) -> fully_closed_over df (v :: v1 :: vl). - Proof. - intros. - apply (fully_closed_over_map df (v1 :: v :: vl) - (fun vl1 => match vl1 with - | a :: b :: vl2 => b :: a :: vl2 - | _ => vl1 - end )); trivial. - unfold In_compat_map. - intros. - destruct vl0; trivial. - destruct vl0; trivial. - unfold In. - unfold In in H0. - tauto. - Qed. - - Lemma fully_closed_over_singleton {T} (df:DefinedFunction UnitAnn T) (v:var_type) - (vl : list var_type): - fully_closed_over df (v::nil) -> fully_closed_over df (v::vl). - Proof. - intros. - induction vl; trivial. - apply fully_closed_over_exchange_vars. - now apply fully_closed_over_cons. - Qed. - - Lemma fully_closed_over_exchange_2vars {T} (df:DefinedFunction UnitAnn T) - (v1 v2 v:var_type) (vl : list var_type): - fully_closed_over df (v1 :: v2 :: v:: vl) -> fully_closed_over df (v :: v1 :: v2 :: vl). - Proof. - intros. - apply (fully_closed_over_map df (v1 :: v2 :: v :: vl) - (fun vl1 => match vl1 with - | a :: b :: c :: vl2 => c :: a :: b :: vl2 - | _ => vl1 - end )); trivial. - unfold In_compat_map. - intros. - destruct vl0; trivial. - destruct vl0; trivial. - destruct vl0; trivial. - unfold In. - unfold In in H0. - tauto. - Qed. - - Lemma fully_closed_over_pair {T} (df:DefinedFunction UnitAnn T) (v1 v2:var_type) - (vl : list var_type): - fully_closed_over df (v1::v2::nil) -> fully_closed_over df (v1::v2::vl). - Proof. - intros. - induction vl; trivial. - apply fully_closed_over_exchange_2vars. - apply fully_closed_over_exchange_2vars. - now apply fully_closed_over_cons. - Qed. - - Lemma fully_closed_subst {T} (vl:list var_type) (df:DefinedFunction UnitAnn T) (v:var_type) - (e':DefinedFunction UnitAnn (snd v)): - fully_closed_over df (v::vl) -> - fully_closed_over e' vl -> - fully_closed_over (df_subst df v e') vl. - Proof. - revert vl. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros; try solve - [easy - | - apply IHdf; [apply H | apply H0] - | - split; destruct H; simpl in *; - [ apply IHdf1; [apply H | apply H0] - | apply IHdf2; [apply H1 | apply H0]] - ]. - - Case "DVector"%string. - apply vforall_forall; intros; simpl in *. - specialize (H i); apply H. - rewrite vforall_forall in H0. - specialize (H0 i). - apply H0. - apply H1. - - Case "DMatrix"%string. - apply vforall_forall; intros. - apply vforall_forall; intros; simpl in *. - specialize (H i i0); apply H. - rewrite vforall_forall in H0; specialize (H0 i). - rewrite vforall_forall in H0; specialize (H0 i0). - apply H0. - apply H1. - - Case "Var"%string. - unfold substvar. - destruct H. - + subst. - unfold substvar. - match_destr; [ | congruence]. - refl_simpler. - simpl; trivial. - + destruct v; destruct v0. - simpl in *. - match_destr. - red in e; subst. - simpl; trivial. - - Case "VectorApply"%string. - destruct H; split; trivial. - + apply IHdf2. - * apply H1. - * apply H0. - - Case "MatrixApply"%string. - destruct H; split; trivial. - + apply IHdf2. - * apply H1. - * apply H0. - - Case "VLossfun"%string. - destruct H; split; trivial. - + apply IHdf2. - * apply H1. - * apply H0. - - Case "MLossfun"%string. - destruct H; split; trivial. - + apply IHdf2. - * apply H1. - * apply H0. - Qed. - - Lemma fully_closed_deriv {T} (df:DefinedFunction UnitAnn T) (s:SubVar) - (vl : list var_type): - fully_closed_over df vl -> - fully_closed_over (df_deriv df (s, DTfloat)) vl. - Proof. - revert s; revert vl. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros; try solve - [easy - | - apply IHdf,H - | - split; try easy; now apply IHdf - | - destruct H; repeat split; try easy; - [apply IHdf2, H0 | apply IHdf1, H] - | - destruct H; repeat split; try easy; - [ apply IHdf1, H | apply IHdf2, H0] - ]. - - Case "DVector"%string. - apply vforall_forall; intros. - apply H. - rewrite vforall_forall in H0. - now apply H0. - - Case "DMatrix"%string. - apply vforall_forall; intros. - apply vforall_forall; intros. - apply H. - rewrite vforall_forall in H0. - specialize (H0 i). - rewrite vforall_forall in H0. - now specialize (H0 i0). - - Case "Max"%string. - destruct H; repeat split; try easy. - apply IHdf2, H0. - apply IHdf1, H. - apply IHdf2, H0. - apply IHdf1, H. - - Case "VectorApply"%string. - apply vforall_forall; intros; simpl in *. - split; destruct H. - apply IHdf2, H0. - apply fully_closed_subst. - + apply IHdf1. - now apply fully_closed_over_singleton. - + simpl; apply H0. - - Case "MatrixApply"%string. - apply vforall_forall; intros; simpl in *. - apply vforall_forall; intros; simpl in *. - split; destruct H. - apply IHdf2, H0. - apply fully_closed_subst. - + apply IHdf1. - now apply fully_closed_over_singleton. - + simpl; apply H0. - - Case "VLossfun"%string. - intros; simpl in *. - split; destruct H. - apply IHdf2, H0. - apply vforall_forall; intros; simpl in *. - apply fully_closed_subst. - apply fully_closed_subst. - + apply IHdf1. - now apply fully_closed_over_pair. - + simpl; apply fully_closed_over_cons; apply H0. - + now simpl. - - Case "MLossfun"%string. - intros; simpl in *. - apply vforall_forall; intros; simpl in *. - apply vforall_forall; intros; simpl in *. - destruct H; split; trivial. - split. - apply IHdf2, H0. - apply fully_closed_subst. - apply fully_closed_subst. - + apply IHdf1. - now apply fully_closed_over_pair. - + simpl; apply fully_closed_over_cons; apply H0. - + now simpl. - Qed. - - Lemma list_env_iter_total_fun {A} (f : A -> df_env -> option df_env) (env : df_env) (l : list A) : - (forall (a:A) (env0: df_env), (f a env0) <> None) -> - list_env_iter f (Some env) l <> None. - Proof. - intros. - generalize env. - induction l; [simpl; congruence|]. - simpl; intros. - specialize (H a env0). - case_eq (f a env0). - - intros; apply (IHl d). - - tauto. - Qed. - - Lemma backprop_deriv_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T) - (grad_env:df_env) (grad: definition_function_types_interp T): - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> df_eval_backprop_deriv σ df grad_env grad <> None. - Proof. - revert grad_env. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros; - try solve [congruence - | - specialize (IHdf1 grad grad_env); simpl in IHdf1; destruct H; - match_option; [|tauto]; - specialize (IHdf2 grad d); simpl in IHdf2; - now apply IHdf2 - ]. - - Case "DVector"%string. - unfold two_vector_env_iter_alt. - rewrite vforall_forall in H0. - apply (list_env_iter_total_fun - (fun i env => df_eval_backprop_deriv σ (x i) env (grad i)) - grad_env (bounded_seq0 n)). - intros. - apply (H a (grad a) env0). - apply (H0 a). - - Case "DMatrix"%string. - unfold two_matrix_env_iter_alt. - rewrite vforall_forall in H0. - apply (list_env_iter_total_fun - (fun i env => - list_env_iter - (fun j env0 => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (Some env) (bounded_seq0 m)) - grad_env (bounded_seq0 n)). - intros. - apply (list_env_iter_total_fun - (fun j env => df_eval_backprop_deriv σ (x a j) env (grad a j)) - env0 (bounded_seq0 m)). - intros. - apply (H a a0 (grad a a0) env1). - specialize (H0 a). - rewrite vforall_forall in H0. - apply (H0 a0). - - Case "Var"%string. - match_destr. - - Case "Minus"%string. - specialize (IHdf1 grad grad_env); simpl in IHdf1; destruct H; - match_option; [|tauto]; - specialize (IHdf2 (-grad)%R d); simpl in IHdf2; - now apply IHdf2. - - Case "Times"%string. - destruct H. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2. - + match_option; [|tauto]. - specialize (IHdf1 (d0 * grad)%R grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (d * grad)%R d1); simpl in IHdf2. - now apply IHdf2. - - Case "Divide"%string. - destruct H. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2. - + match_option; [|tauto]. - specialize (IHdf1 (grad / d0)%R grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (-d / (d0 * d0) * grad)%R d1); simpl in IHdf2. - now apply IHdf2. - - Case "Square"%string. - generalize (eval_fully_closed_not_none σ df); intros; simpl in H0. - match_option; [|tauto]. - specialize (IHdf (2 * d * grad)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "Exp"%string. - generalize (eval_fully_closed_not_none σ df); intros; simpl in H0. - match_option; [|tauto]. - specialize (IHdf (grad * exp d)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "Log"%string. - generalize (eval_fully_closed_not_none σ df); intros; simpl in H0. - match_option; [|tauto]. - specialize (IHdf (grad / d)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "Abs"%string. - generalize (eval_fully_closed_not_none σ df); intros; simpl in H0. - match_option; [|tauto]. - specialize (IHdf (grad * sign d)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "Sign"%string. - specialize (IHdf (0)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "PSign"%string. - specialize (IHdf (0)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "Max"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - match_destr. - specialize (IHdf2 grad grad_env). - now apply IHdf2. - specialize (IHdf1 grad grad_env). - now apply IHdf1. - - Case "VectorDot"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - specialize (IHdf1 (vmap (fun rv : R => (rv * grad)%R) d0) grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (vmap (fun lv : R => (lv * grad)%R) d) d1). - now apply IHdf2. - - Case "VectorSum"%string. - specialize (IHdf (ConstVector n grad) grad_env). - now apply IHdf. - - Case "MatrixSum"%string. - specialize (IHdf (ConstMatrix m n grad) grad_env). - now apply IHdf. - - Case "VectorElem"%string. - specialize (IHdf (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i) then grad else 0%R) grad_env). - now apply IHdf. - - Case "MatrixElem"%string. - specialize (IHdf (fun (k1 : {n' : nat | n' < m}) - (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) then if - equiv_dec (` k2) (` j) then grad else 0%R else 0%R) - grad_env). - now apply IHdf. - - Case "MatrixVectorMult"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - specialize (IHdf1 (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => (grad i * d0 j)%R) - grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d j i) - grad) - d1). - now apply IHdf2. - - Case "MatrixVectorAdd"%string. - specialize (IHdf1 grad grad_env); simpl in IHdf1. - match_option; [|tauto]. - match_option. - generalize (list_env_iter_total_fun - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (@transpose R m n grad i)) - d (bounded_seq0 n)); intros. - cut_to H0; [congruence|]. - intros; destruct H. - now apply IHdf2. - - Case "MatrixMult"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - specialize (IHdf1 (matrix_mult grad (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < p}) => d0 j i)) - grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d j i) grad) - d1). - now apply IHdf2. - - Case "VectorMinus"%string. - specialize (IHdf1 grad grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad i)%R) d). - now apply IHdf2. - - Case "MatrixMinus"%string. - specialize (IHdf1 grad grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (- grad i j)%R) - d). - now apply IHdf2. - - Case "VectorScalMult"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - specialize (IHdf1 (vsum (fun j : {n' : nat | n' < n} => (d0 j * grad j)%R)) - grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (fun j : {n' : nat | n' < n} => (d * grad j)%R) d1). - now apply IHdf2. - - Case "MatrixScalMult"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - specialize (IHdf1 (msum - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (d0 i j * grad i j)%R)) - grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad i j * d)%R) - d1). - now apply IHdf2. - - Case "VectorApply"%string. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0. - match_option; [|tauto]. - match_option. - + specialize (IHdf2 v0 grad_env). - now apply IHdf2. - + apply vectoro_to_ovector_exists_None in eqq0. - destruct eqq0. - rewrite vmap_nth in e; simpl in e. - destruct H. - match_option_in e. - generalize (fully_closed_deriv df1 v ((v,DTfloat):: nil)); intros. - cut_to H2; trivial. - generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (d x) :: nil) - (df_deriv df1 (v, DTfloat))); intros. - simpl in H3; cut_to H3; tauto. - - Case "MatrixApply"%string. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0. - match_option; [|tauto]. - match_option. - + specialize (IHdf2 m0 grad_env). - now apply IHdf2. - + unfold matrixo_to_omatrix in eqq0. - apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0. - apply vectoro_to_ovector_exists_None in e; destruct e. - unfold mmap in e. - rewrite vmap_nth in e; simpl in e. - rewrite vmap_nth in e; simpl in e. - destruct H. - unfold matrix_zip in e. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - generalize (fully_closed_deriv df1 v ((v,DTfloat):: nil)). - intros; cut_to H2; trivial. - generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (d x x0) :: nil) - (df_deriv df1 (v, DTfloat))); intros. - simpl in H3; cut_to H3; tauto. - - Case "VLossfun"%string. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0. - match_option; [|tauto]. - match_option. - + specialize (IHdf2 v grad_env). - now apply IHdf2. - + apply vectoro_to_ovector_exists_None in eqq0. - destruct eqq0. - rewrite vmap_nth in e; simpl in e. - destruct H. - match_option_in e. - generalize (fully_closed_deriv df1 v1 ((v1,DTfloat)::(v2,DTfloat)::nil)). - intros; cut_to H2; trivial. - generalize (eval_fully_closed_not_none (mk_env_entry (v1, DTfloat) (d x) :: - mk_env_entry (v2, DTfloat) (r x) :: nil) - (df_deriv df1 (v1, DTfloat))); intros. - simpl in H3; cut_to H3; tauto. - - Case "MLossfun"%string. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0. - match_option; [|tauto]. - match_option. - + specialize (IHdf2 m0 grad_env). - now apply IHdf2. - + unfold matrixo_to_omatrix in eqq0. - apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0. - apply vectoro_to_ovector_exists_None in e; destruct e. - unfold mmap in e. - rewrite vmap_nth in e; simpl in e. - rewrite vmap_nth in e; simpl in e. - destruct H. - unfold matrix_zip in e. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - generalize (fully_closed_deriv df1 v1 ((v1,DTfloat)::(v2,DTfloat)::nil)). - intros; cut_to H2; trivial. - generalize (eval_fully_closed_not_none (mk_env_entry (v1, DTfloat) (d x x0) :: - mk_env_entry (v2, DTfloat) (r x x0) :: nil) - (df_deriv df1 (v1, DTfloat))); intros. - simpl in H3; cut_to H3; tauto. - Qed. - - Lemma backprop_deriv_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T) - (grad_env:df_env) (grad: definition_function_types_interp T): - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - {d:df_env | df_eval_backprop_deriv σ df grad_env grad = Some d}. - Proof. - case_eq (df_eval_backprop_deriv σ df grad_env grad); intros. - - now exists d. - - generalize (backprop_deriv_fully_closed_not_none σ df grad_env grad). - intros; simpl in *. - cut_to H1; tauto. - Qed. - - Lemma eval_deriv_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T) - (v:var_type): - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> df_eval_deriv σ df v <> None. - Proof. - revert σ v. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; intros; simpl in *; - try solve [ - congruence - | - destruct H - ; specialize (IHdf1 σ v); specialize (IHdf2 σ v) - ;cut_to IHdf1; trivial - ;match_option; [|tauto] - ;cut_to IHdf2; trivial - ;match_option; tauto - | - destruct H; - specialize (IHdf1 σ v); specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1; - cut_to H1; trivial; - match_option; [|tauto]; - cut_to IHdf1; trivial; - match_option; [|tauto]; - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2; - match_option; [|tauto]; - cut_to IHdf2; trivial; - match_option; tauto - | - generalize (eval_fully_closed_not_none σ df); intros; - specialize (IHdf σ v); - simpl in H0; cut_to H0; trivial; - match_option; [|tauto]; - cut_to IHdf; trivial; - match_option; tauto - | - specialize (IHdf σ v); - generalize (eval_fully_closed_not_none σ df); intros; - simpl in H0; cut_to H0; trivial; - match_option; tauto - ]. - - Case "DVector"%string. - apply vectoro_to_ovector_not_none; intros; apply H. - rewrite vforall_forall in H0; apply H0. - - Case "DMatrix"%string. - apply vectoro_to_ovector_not_none; intros. - apply vectoro_to_ovector_not_none; intros; apply H. - rewrite vforall_forall in H0; specialize (H0 i). - rewrite vforall_forall in H0; apply H0. - - Case "Max"%string. - destruct H; - specialize (IHdf1 σ v); specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1; - cut_to H1; trivial; - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2; - match_option; [|tauto]. - case_eq ( Rle_dec d d0 ); intros. - cut_to IHdf2; trivial. - cut_to IHdf1; trivial. - - Case "VectorApply"%string. - destruct H. - specialize (IHdf2 σ v0); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - apply vectoro_to_ovector_not_none; intros. - match_option. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil) (v, DTfloat)). - cut_to IHdf1; trivial. - tauto. - - Case "MatrixApply"%string. - destruct H. - specialize (IHdf2 σ v0); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - apply vectoro_to_ovector_not_none; intros. - apply vectoro_to_ovector_not_none; intros. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil) (v, DTfloat)). - cut_to IHdf1; trivial. - match_option; tauto. - - Case "VLossfun"%string. - destruct H. - specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - match_option. - apply vectoro_to_ovector_not_none in eqq1. - tauto. - intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil) (v1, DTfloat)). - cut_to IHdf1; trivial. - match_option; tauto. - - Case "MLossfun"%string. - destruct H. - specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - match_option. - unfold matrixo_to_omatrix in eqq1. - apply vectoro_to_ovector_not_none in eqq1. - tauto. - intros. - apply vectoro_to_ovector_not_none; intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil)(v1, DTfloat)). - cut_to IHdf1; trivial. - match_option; tauto. - Qed. - - Lemma eval_deriv_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T) - (v:var_type): - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - {d:definition_function_types_interp T | df_eval_deriv σ df v = Some d}. - Proof. - case_eq (df_eval_deriv σ df v); intros. - - now exists d. - - generalize (eval_deriv_fully_closed_not_none σ df v). - intros; simpl in *. - cut_to H1; tauto. - Qed. - - Lemma eval_deriv_genvar_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T) - (v:df_env): - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> df_eval_deriv_genvar σ df v <> None. - Proof. - revert σ v. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; intros; simpl in *; - try solve [ - congruence - | - destruct H - ; specialize (IHdf1 σ v); specialize (IHdf2 σ v) - ;cut_to IHdf1; trivial - ;match_option; [|tauto] - ;cut_to IHdf2; trivial - ;match_option; tauto - | - destruct H; - specialize (IHdf1 σ v); specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1; - cut_to H1; trivial; - match_option; [|tauto]; - cut_to IHdf1; trivial; - match_option; [|tauto]; - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2; - match_option; [|tauto]; - cut_to IHdf2; trivial; - match_option; tauto - | - generalize (eval_fully_closed_not_none σ df); intros; - specialize (IHdf σ v); - simpl in H0; cut_to H0; trivial; - match_option; [|tauto]; - cut_to IHdf; trivial; - match_option; tauto - | - specialize (IHdf σ v); - generalize (eval_fully_closed_not_none σ df); intros; - simpl in H0; cut_to H0; trivial; - match_option; tauto - ]. - - Case "DVector"%string. - apply vectoro_to_ovector_not_none; intros; apply H. - rewrite vforall_forall in H0; apply H0. - - Case "DMatrix"%string. - apply vectoro_to_ovector_not_none; intros. - apply vectoro_to_ovector_not_none; intros; apply H. - rewrite vforall_forall in H0; specialize (H0 i). - rewrite vforall_forall in H0; apply H0. - - Case "Max"%string. - destruct H; - specialize (IHdf1 σ v); specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1; - cut_to H1; trivial; - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2; - match_option; [|tauto]. - case_eq ( Rle_dec d d0 ); intros. - cut_to IHdf2; trivial. - cut_to IHdf1; trivial. - - Case "VectorApply"%string. - destruct H. - specialize (IHdf2 σ v0); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - apply vectoro_to_ovector_not_none; intros. - match_option. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil) (mk_genvar_env v)). - cut_to IHdf1; trivial. - tauto. - - Case "MatrixApply"%string. - destruct H. - specialize (IHdf2 σ v0); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - apply vectoro_to_ovector_not_none; intros. - apply vectoro_to_ovector_not_none; intros. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil) (mk_genvar_env v)). - cut_to IHdf1; trivial. - match_option; tauto. - - Case "VLossfun"%string. - destruct H. - specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - match_option. - apply vectoro_to_ovector_not_none in eqq1. - tauto. - intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil) (mk_genvar_env v1)). - cut_to IHdf1; trivial. - match_option; tauto. - - Case "MLossfun"%string. - destruct H. - specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - match_option. - unfold matrixo_to_omatrix in eqq1. - apply vectoro_to_ovector_not_none in eqq1. - tauto. - intros. - apply vectoro_to_ovector_not_none; intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil) (mk_genvar_env v1)). - cut_to IHdf1; trivial. - match_option; tauto. - Qed. - - Definition scalarMult (T : definition_function_types) (c : float) := - match T return - definition_function_types_interp T -> definition_function_types_interp T with - | DTfloat => fun f => (c * f)%R - | DTVector n => fun f => fun i => (c * f i)%R - | DTMatrix n m => fun f => fun i j => (c * f i j)%R - end. - - Definition dfti_gen_plus {T : definition_function_types} := - match T return - definition_function_types_interp T -> definition_function_types_interp T -> - definition_function_types_interp T with - | DTfloat => fun f g => (f + g)%R - | DTVector n => fun f g => fun i => (f i + g i)%R - | DTMatrix n m => fun f g => fun i j => (f i j + g i j)%R - end. - - Lemma subvar_addvar_scalar_neq (env : df_env) (oval : float) (s : SubVar) (v: var_type) (grad : definition_function_types_interp (snd v)) : - let sv := (s, DTfloat) in - vartlookup env sv = Some oval -> - v <> sv -> - subvar sv (vart_update env v (addvar v env grad)) oval = 0%R. - Proof. - intros. - unfold subvar; simpl. - rewrite lookup_update_neq; trivial. - rewrite H. - lra. - Qed. - - Lemma subvar_addvar_scalar_eq (env : df_env) (s : SubVar) (oval grad : float) : - let v := (s, DTfloat) in - vartlookup env v = Some oval -> - subvar v (vart_update env v (addvar v env grad)) oval = grad. - Proof. - intros. - unfold subvar; simpl. - rewrite lookup_update. - unfold addvar; simpl. - rewrite H. - lra. - Qed. - - Lemma split_subvar (env1 env2: df_env) (oval val1 : float) (s : SubVar) : - let v := (s, DTfloat) in - vartlookup env1 v = Some val1 -> - subvar v env2 oval = (subvar v env2 val1 + subvar v env1 oval)%R. - Proof. - intros. - unfold subvar; simpl. - rewrite H. - case_eq (vartlookup env2 v); intros. - lra. - lra. - Qed. - - Lemma vsum_nil : vsum vnil = 0%R. - Proof. - reflexivity. - Qed. - - Lemma vsum_mult {n} (v : Vector float n) (c : float) : - (c * vsum v)%R = vsum (fun j => (c * v j)%R). - Proof. - unfold vsum, vector_fold_right1, Datatypes.id; simpl. - induction n; [ | destruct n]. - - repeat rewrite vector_fold_right1_dep_0; lra. - - repeat rewrite vector_fold_right1_dep_1; lra. - - repeat rewrite vector_fold_right1_dep_SSn. - rewrite Rmult_plus_distr_l. - specialize (IHn (vdrop_last v)); simpl in IHn. - rewrite IHn. - f_equal. - apply vector_fold_right1_dep_ext. - intros [i pf]; trivial. - Qed. - - Lemma vsum_plus {m:nat} (v1 v2:Vector R m) : - (vsum v1 + vsum v2)%R = vsum (fun i => (v1 i + v2 i)%R). - Proof. - unfold vsum, vector_fold_right1, Datatypes.id; simpl. - induction m; [ | destruct m]. - - repeat rewrite vector_fold_right1_dep_0; lra. - - repeat rewrite vector_fold_right1_dep_1; lra. - - repeat rewrite vector_fold_right1_dep_SSn. - specialize (IHm (vdrop_last v1) (vdrop_last v2)); simpl in IHm. - rewrite (Rplus_comm (vlast v2)). - rewrite (Rplus_assoc (vlast v1)). - rewrite <- (Rplus_assoc _ _ (vlast v2)). - rewrite IHm. - rewrite (Rplus_comm _ (vlast v2)). - rewrite <- Rplus_assoc. - f_equal. - apply vector_fold_right1_dep_ext. - intros [i pf]; trivial. - Qed. - - Lemma vmap_mult {n} (f: float -> float) (v : Vector float n) (c : float) : - forall i : {n' : nat | n' < n}, - (c * (vmap f v) i)%R = (vmap (fun x => (c * f x)%R) v) i. - Proof. - intros. - rewrite vmap_nth. - now rewrite vmap_nth. - Qed. - - Lemma vsum_ext {n} (v v':Vector float n) : vec_eq v v' -> vsum v = vsum v'. - Proof. - apply vector_fold_right1_ext. - Qed. - - Lemma msum_ext {m n} (mat mat':Matrix float m n) : - (forall i j, mat i j = mat' i j) -> msum mat = msum mat'. - Proof. - intros. - apply vsum_ext; intros ?. - repeat rewrite vmap_nth. - apply vsum_ext; intros ?; auto. - Qed. - - Lemma msum_mult {m n} (mat : Matrix float m n) (c : float) : - (c * msum mat)%R = msum (fun i j => (c * mat i j)%R). - Proof. - unfold msum. - rewrite vsum_mult. - apply vsum_ext; intros i. - repeat rewrite vmap_nth. - now rewrite vsum_mult. - Qed. - - Lemma msum_mmap_mult {m n} (mat : Matrix float m n) (c : float) : - (c * msum mat)%R = msum (mmap (fun x => c * x)%R mat). - Proof. - rewrite msum_mult. - apply msum_ext; intros i j. - now rewrite mmap_nth. - Qed. - - Lemma msum_mmap_div_denom {m n} (mat : Matrix float m n) (c : float) : - msum (mmap (fun u : R => (u / c)%R) mat) = (msum mat / c)%R. - Proof. - transitivity (msum (mmap (fun u : R => (/ c * u)%R) mat)). - - apply msum_ext; intros i j. - repeat rewrite mmap_nth. - lra. - - rewrite <- msum_mmap_mult. - lra. - Qed. - - Lemma vsum0 n : vsum (fun _ : {n' : nat | (n' < n)%nat} => 0%R) = 0%R. - Proof. - generalize (vsum_mult (fun _ : {n' : nat | (n' < n)%nat} => 0%R) 0%R); intros HH. - rewrite Rmult_0_l in HH. - symmetry. - simpl in *. - erewrite vsum_ext; [eassumption | ]. - intro; simpl; lra. - Qed. - - Lemma vsum_unitvector {n} (v:Vector R n) i : - vsum (fun j => (v j * UnitVector n i j)%R) = v i. - Proof. - unfold vsum, vector_fold_right1, Datatypes.id, UnitVector; simpl. - revert n v i. - destruct i. - induction n; [ | destruct n]. - - lia. - - repeat rewrite vector_fold_right1_dep_1. - destruct x; [ | lia]; simpl. - field_simplify. - now erewrite index_pf_irrel. - - repeat rewrite vector_fold_right1_dep_SSn. - unfold vlast, vdrop_last; simpl. - destruct (equiv_dec (S n) x). - + ring_simplify. - simpl. - destruct e. - match goal with - | [|- (_ + ?x)%R = _ ] => replace x with 0%R - end. - * ring_simplify. - now erewrite index_pf_irrel. - * rewrite <- (vsum0 (S n)) at 1. - unfold vsum, vector_fold_right1, Fzero, Datatypes.id; simpl. - apply (@vector_fold_right1_dep_ext (fun _ => R)). - intros [??]. - destruct (equiv_dec x (S n)). - -- destruct e. - lia. - -- lra. - + ring_simplify. - unfold equiv, complement in c. - assert (pf:x < S n) by lia. - specialize (IHn (vdrop_last v) pf). - simpl in IHn. - erewrite index_pf_irrel. - rewrite <- IHn. - apply (@vector_fold_right1_dep_ext (fun _ => R)). - now intros [??]. - Qed. - - Lemma msum_unitmatrix {m n} (v:Matrix R m n) i j : - msum (fun k l => (v k l * UnitMatrix m n i j k l)%R) = v i j. - Proof. - unfold msum. - unfold UnitMatrix. - rewrite (vsum_ext _ ( - (fun (k : {n' : nat | n' < m}) => @vsum floatish_R _ - (fun (l : {m' : nat | m' < n}) => - (v k l * - (if equiv_dec (` k) (` i) then if equiv_dec (` l) (` j) then 1%R else 0%R else 0%R))%R)) - - )) - by (intros ?; now rewrite vmap_nth). - rewrite (vsum_ext _ ( - (fun (k : {n' : nat | n' < m}) => (if equiv_dec (` k) (` i) then - @vsum floatish_R _ - (fun (l : {m' : nat | m' < n}%nat) => - ((v k) l * - if equiv_dec (` l) (` j) then 1%R else 0%R))%R else 0%R)) - - )). - - rewrite (vsum_ext _ ( - (fun (k : {n' : nat | n' < m}) => (if equiv_dec (` k) (` i) then - v k j else 0%R)) - - )). - + rewrite (vsum_ext _ ( - (fun (k : {n' : nat | n' < m}) => ((transpose v) j k * @UnitVector floatish_R m i k)%R) - - )). - * now rewrite vsum_unitvector. - * unfold UnitVector; intros ?; simpl. - dest_eqdec; unfold transpose; simpl - ; lra. - + intros ?. - dest_eqdec; trivial. - apply vsum_unitvector. - - intros ?. - dest_eqdec; trivial. - rewrite <- (vsum0 n) at 1. - apply vsum_ext. - intros ?; lra. - Qed. - - Ltac vectoro_assert_forall_in H i - := match type of H with vectoro_to_ovector ?x = Some ?y => - assert (forall i, x i = Some (y i)) end. - - Lemma vartlookup_list_env_iter {A} - (s: SubVar) - (f : A -> df_env -> option df_env) - (l : list A) (env fenv: df_env): - list_env_iter f (Some env) l = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - (forall (env fenv: df_env) (i:A), - f i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - vartlookup fenv (s, DTfloat) <> None. - Proof. - intros. - revert H0 H. - generalize env. - induction l. - - simpl; intros. - now invcs H. - - simpl; intros. - generalize (list_env_iter_none f l); intros. - assert (f a env0 <> None); [congruence | ]. - case_eq (f a env0); [|congruence]. - intros. - apply (IHl d). - + specialize (H1 env0 d a). - now apply H1. - + now rewrite H4 in H. - Qed. - - (* - Theorem df_eval_deriv_same {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s:SubVar) : - let v := (s, DTfloat) in - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - df_eval_deriv σ df v = df_eval σ (df_deriv df v). - Proof. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl in *;trivial; intros - ; try (destruct H; rewrite IHdf1; trivial; rewrite IHdf2; trivial) - ; try (rewrite IHdf; trivial; do 2 match_option). - - Case "DVector"%string. - f_equal. - apply functional_extensionality; intros. - rewrite vforall_forall in H0. - specialize (H x0); simpl in H. - specialize (H0 x0). - apply H; trivial. - - Case "DMatrix"%string. - f_equal. - apply functional_extensionality; intros. - apply functional_extensionality; intros. - rewrite vforall_forall in H0. - specialize (H x0); simpl in H. - specialize (H0 x0). - rewrite vforall_forall in H0. - apply H; trivial. - - case "Times"%string. - intros; do 4 match_option. - - Case "Divide"%string. - intros; do 4 match_option. - - Case "Sign"%string. - match_option. - generalize (eval_deriv_fully_closed_not_none σ df (s, DTfloat)); tauto. - - Case "PSign"%string. - match_option. - generalize (eval_deriv_fully_closed_not_none σ df (s, DTfloat)); tauto. - - Case "Max"%string. - assert (df_eval σ df1 <> None) by (apply eval_fully_closed_not_none; trivial). - assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - match_option; [|congruence]. - assert (df_eval σ (df_deriv df1 (s, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - assert (df_eval σ (df_deriv df2 (s, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - destruct (df_eval σ (df_deriv df1 (s, DTfloat))); [|congruence]. - destruct (df_eval σ (df_deriv df2 (s, DTfloat))); [|congruence]. - unfold pos_sign; simpl. - case_eq (Rle_dec d d0); intros; f_equal. - destruct (Rge_dec (d0 - d) 0); lra. - destruct (Rge_dec (d0 - d) 0); lra. - - Case "VectorDot"%string. - do 4 match_option. - rewrite vsum_plus. - do 2 f_equal. - f_equal. - apply functional_extensionality; intros. - lra. - - Case "MatrixVectorMult"%string. - do 4 match_option. - f_equal. - apply functional_extensionality; intros. - unfold matrix_vector_mult. - rewrite vsum_plus; simpl. - f_equal. - apply functional_extensionality; intros. - lra. - - Case "MatrixMult"%string. - do 4 match_option. - f_equal. - apply functional_extensionality; intros. - apply functional_extensionality; intros. - unfold matrix_mult. - rewrite vsum_plus; simpl; f_equal. - apply functional_extensionality; intros. - lra. - - Case "VectorScalMult"%string. - do 4 match_option. - f_equal. - apply functional_extensionality; intros. - lra. - - Case "MatrixScalMult"%string. - do 4 match_option. - f_equal. - apply functional_extensionality; intros. - apply functional_extensionality; intros. - lra. - - Case "VectorApply"%string. - destruct H. - assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - rewrite IHdf2; trivial. - match_option. - + f_equal. - apply functional_extensionality; intros. - assert ( df_eval_deriv [mk_env_entry (v, DTfloat) (d x)] df1 (v, DTfloat) = - df_eval σ (df_subst (df_deriv df1 (v, DTfloat)) (v, DTfloat) - (VectorElem () df2 x))). - XXX - now rewrite H2. - + assert ( df_eval σ (df_deriv df2 (s, DTfloat)) <> None ); [|tauto]. - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - - Case "MatrixApply"%string. - destruct H. - assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - rewrite IHdf2; trivial. - match_option. - + f_equal. - apply functional_extensionality; intros. - apply functional_extensionality; intros. - assert ( df_eval_deriv [mk_env_entry (v, DTfloat) (d x x0)] df1 (v, DTfloat) = - df_eval σ (df_subst (df_deriv df1 (v, DTfloat)) (v, DTfloat) - (MatrixElem () df2 x x0))). - XXX - now rewrite H2. - + assert ( df_eval σ (df_deriv df2 (s, DTfloat)) <> None ); [|tauto]. - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - - Case "VLossfun"%string. - destruct H. - assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - rewrite IHdf2; trivial. - match_option. - do 2 match_option. - + do 2 f_equal. - apply functional_extensionality; intros. - XXX - + generalize (vectoro_to_ovector_exists_None eqq2). - intros; destruct H2. - assert (df_eval - σ - (df_subst (df_subst (df_deriv df1 (v1, DTfloat)) (v1, DTfloat) (VectorElem () df2 x)) - (v2, DTfloat) (Number () (r x))) <> None); [|tauto]. - apply eval_fully_closed_not_none. - apply fully_closed_subst; simpl; [|trivial]. - apply fully_closed_subst; simpl. - * apply fully_closed_deriv. - now apply fully_closed_over_pair. - * now apply fully_closed_over_cons. - + generalize (vectoro_to_ovector_exists_None eqq1). - intros; destruct H2. - match_option_in e. - assert (df_eval_deriv [mk_env_entry (v1, DTfloat) (d x); mk_env_entry (v2, DTfloat) (r x)] - df1 (v1, DTfloat) <> None); [|tauto]. - apply eval_deriv_fully_closed_not_none; trivial. - - Case "MLossfun"%string. - destruct H. - assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - rewrite IHdf2; trivial. - match_option. - do 2 match_option. - + do 2 f_equal. - XXX - + generalize (vectoro_to_ovector_exists_None eqq2); intros; destruct H2. - generalize (vectoro_to_ovector_exists_None e); intros; destruct H2. - match_option_in e0. - match_option_in eqq3. - assert (df_eval σ - (df_subst - (df_subst (df_deriv df1 (v1, DTfloat)) (v1, DTfloat) (MatrixElem () df2 x x0)) - (v2, DTfloat) (Number () (r x x0))) <> None); [|tauto]. - apply eval_fully_closed_not_none. - apply fully_closed_subst; simpl; [|trivial]. - apply fully_closed_subst; simpl. - * apply fully_closed_deriv. - now apply fully_closed_over_pair. - * now apply fully_closed_over_cons. - + generalize (vectoro_to_ovector_exists_None eqq1). - intros; destruct H2. - generalize (vectoro_to_ovector_exists_None e). - intros; destruct H2. - match_option_in e0. - assert (df_eval_deriv [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] - df1 (v1, DTfloat) <> None); [|tauto]. - apply eval_deriv_fully_closed_not_none; trivial. - *) - - Theorem df_eval_deriv_scalar_same (σ:df_env) (df:DefinedFunction UnitAnn DTfloat) (s:SubVar) : - let v := (s, DTfloat) in - let vl := map (fun ve => projT1 ve) σ in - is_scalar_function df -> - fully_closed_over df vl -> - df_eval_deriv σ df v = df_eval σ (df_deriv df v). - Proof. - simpl. - intros is_scalar. - generalize is_scalar. - pattern df. - revert df is_scalar. - DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case - ; simpl; trivial; intros - ; try (destruct H1; destruct is_scalar; - specialize (H H3 H1); specialize (H0 H4 H2); now rewrite H, H0) - ; try (rewrite H; trivial; do 2 match_option). - - Case "Times"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - rewrite H, H0. - do 4 match_option. - - Case "Divide"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - rewrite H, H0. - do 4 match_option. - - Case "Sign"%string. - match_option. - assert ( df_eval_deriv σ e (s, DTfloat) <> None); [|tauto]. - apply eval_deriv_fully_closed_not_none; trivial. - - Case "PSign"%string. - match_option. - assert ( df_eval_deriv σ e (s, DTfloat) <> None); [|tauto]. - apply eval_deriv_fully_closed_not_none; trivial. - - Case "Max"%string. - destruct is_scalar; destruct H1. - cut_to H; trivial. - cut_to H0; trivial. - rewrite H, H0. - assert (df_eval σ l <> None) by (apply eval_fully_closed_not_none; trivial). - assert (df_eval σ r <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - match_option; [|congruence]. - assert (df_eval σ (df_deriv l (s, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - assert (df_eval σ (df_deriv r (s, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - destruct (df_eval σ (df_deriv l (s, DTfloat))); [|congruence]. - destruct (df_eval σ (df_deriv r (s, DTfloat))); [|congruence]. - unfold pos_sign; simpl. - case_eq (Rle_dec d d0); intros; f_equal. - destruct (Rge_dec (d0 - d) 0); lra. - destruct (Rge_dec (d0 - d) 0); lra. - Qed. - - Lemma vartlookup_list_env_iter2 {A} - (s: SubVar) - {f : A -> df_env -> option df_env} - {l : list A} {env fenv: df_env}: - list_env_iter f (Some env) l = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - (forall (env fenv: df_env) (i:A), - f i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - vartlookup fenv (s, DTfloat) <> None. - Proof. - apply (vartlookup_list_env_iter s f l env fenv). - Qed. - - Lemma scalarMult_list_env_iter - (s: SubVar) (c val1 val2:float) (A :Type) - (f g : A -> df_env -> option df_env) - (l : list A) (env1 env2 fenv1 fenv2: df_env): - list_env_iter f (Some env1) l = Some fenv1 -> - list_env_iter g (Some env2) l = Some fenv2 -> - vartlookup env1 (s, DTfloat) = Some val1 -> - vartlookup env2 (s, DTfloat) = Some val2 -> - (forall (i:A) (env1 env2 fenv1 fenv2: df_env) (v1 v2: float), - vartlookup env1 (s, DTfloat) = Some v1 -> - vartlookup env2 (s, DTfloat) = Some v2 -> - f i env1 = Some fenv1 -> g i env2 = Some fenv2 -> - subvar (s, DTfloat) fenv1 v1 = (c * subvar (s, DTfloat) fenv2 v2)%R) -> - (forall (env fenv: df_env) (i:A), - f i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - (forall (env fenv: df_env) (i:A), - g i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - subvar (s, DTfloat) fenv1 val1 = (c * subvar (s, DTfloat) fenv2 val2)%R. - Proof. - intros. - generalize (vartlookup_list_env_iter s f l env1 fenv1); intros. - assert (vartlookup env1 (s, DTfloat) <> None). - rewrite H1; discriminate. - assert (vartlookup env2 (s, DTfloat) <> None). - rewrite H2; discriminate. - specialize (H6 H H7 H4). - generalize (vartlookup_list_env_iter s g l env2 fenv2); intros. - specialize (H9 H0 H8 H5). - revert H1 H2 H H0. - generalize env1 env2 val1 val2. - induction l. - - intros. - unfold subvar; simpl. - unfold list_env_iter in H; simpl in H. - unfold list_env_iter in H0; simpl in H0. - invcs H; invcs H0. - rewrite H1; rewrite H2; lra. - - simpl; intros. - generalize (list_env_iter_none f l); intros. - assert (f a env0 <> None); [congruence | ]. - case_eq (f a env0); [intros|congruence]. - generalize (list_env_iter_none g l); intros. - assert (g a env3 <> None); [congruence | ]. - case_eq (g a env3); [intros|congruence]. - assert (vartlookup d (s, DTfloat) <> None). - apply (H4 env0 d a); trivial; congruence. - assert (vartlookup d0 (s, DTfloat) <> None). - apply (H5 env3 d0 a); trivial; congruence. - case_eq (vartlookup d (s, DTfloat)); [intros | tauto]. - case_eq (vartlookup d0 (s, DTfloat)); [intros | tauto]. - specialize (IHl d d0 d1 d2). - specialize (H3 a env0 env3 d d0 val0 val3). - rewrite (split_subvar d fenv1 val0 d1); trivial. - rewrite (split_subvar d0 fenv2 val3 d2); trivial. - specialize (H3 H1 H2 H12 H15). - rewrite H12 in H. - rewrite H15 in H0. - specialize (IHl H18 H19 H H0). - lra. - Qed. - - Lemma list_env_iter_subvar_env2 - (s: SubVar) (val1 val2:float) (A :Type) - (f g : A -> df_env -> option df_env) - (l : list A) (env1 env2 fenv1 fenv2: df_env): - list_env_iter f (Some env1) l = Some fenv1 -> - list_env_iter g (Some env2) l = Some fenv2 -> - vartlookup env1 (s, DTfloat) = Some val1 -> - vartlookup env2 (s, DTfloat) = Some val2 -> - (forall (i:A) (env1 env2 fenv1 fenv2: df_env) (v1 v2: float), - vartlookup env1 (s, DTfloat) = Some v1 -> - vartlookup env2 (s, DTfloat) = Some v2 -> - f i env1 = Some fenv1 -> g i env2 = Some fenv2 -> - subvar (s, DTfloat) fenv1 v1 = subvar (s, DTfloat) fenv2 v2) -> - (forall (env fenv: df_env) (i:A), - f i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - (forall (env fenv: df_env) (i:A), - g i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - subvar (s, DTfloat) fenv1 val1 = subvar (s, DTfloat) fenv2 val2. - Proof. - intros. - generalize (scalarMult_list_env_iter s 1%R val1 val2 A f g l env1 env2 fenv1 fenv2). - intros. - specialize (H6 H H0 H1 H2). - cut_to H6. - now replace (1 * subvar (s, DTfloat) fenv2 val2)%R with (subvar (s, DTfloat) fenv2 val2) in H6 by lra. - intros. - replace (1 * subvar (s, DTfloat) fenv3 v2)%R with (subvar (s, DTfloat) fenv3 v2) by lra. - apply (H3 i env0 env3 fenv0 fenv3); trivial. - apply H4. - apply H5. - Qed. - - Lemma scalarMult_backprop_grad_scalar {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (s: SubVar) (grad_env1 grad_env2:df_env) (grad : definition_function_types_interp T) (c:float) : - let v := (s, DTfloat) in - vartlookup grad_env1 v <> None -> vartlookup grad_env2 v <> None -> - df_eval_backprop_deriv σ df grad_env1 (scalarMult T c grad) <> None -> - df_eval_backprop_deriv σ df grad_env2 grad <> None -> - df_eval_backprop_delta σ df v grad_env1 (scalarMult T c grad) = - lift (fun e => scalarMult (snd v) c e) (df_eval_backprop_delta σ df v grad_env2 grad). - Proof. - revert grad_env1 grad_env2. - unfold df_eval_backprop_delta. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case - ; simpl; intros grad_env1 grad_env2 neq1 neq2; intros. - - Case "Number"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [|tauto]. - intros; simpl; f_equal. - unfold subvar; simpl. - match_destr; match_destr. - inversion H1; subst. - inversion H2; subst. - lra. - - Case "Constant"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - intros; simpl; f_equal. - unfold subvar; simpl. - match_destr; match_destr. - inversion H1; subst. - inversion H2; subst. - lra. - - Case "DVector"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - intros; simpl. - unfold lift. - match_option; [|tauto]. - case_eq (two_vector_env_iter_alt - (fun (x0 : DefinedFunction Ann DTfloat) (g : R) (genv : df_env) => - df_eval_backprop_deriv σ x0 genv g) grad_env2 x grad); [|tauto]. - unfold two_vector_env_iter_alt in *. - intros; f_equal. - apply (scalarMult_list_env_iter - s c d0 d {n' : nat | n' < n} - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (c * grad i)%R) - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad i)) - (bounded_seq0 n) grad_env1 grad_env2); trivial. - + intros. - specialize (H i (grad i) env1 env2). - assert (vartlookup env1 (s, DTfloat) <> None); [congruence|]. - assert (vartlookup env2 (s, DTfloat) <> None); [congruence|]. - specialize (H H9 H10). - assert (df_eval_backprop_deriv σ (x i) env1 (c * grad i)%R <> None); [congruence|]. - assert (df_eval_backprop_deriv σ (x i) env2 (grad i) <> None); [congruence|]. - specialize (H H11 H12). - unfold lift in H; simpl in H. - rewrite H5, H6, H7, H8 in H. - now inversion H. - + intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H5). - + intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H5). - - Case "DMatrix"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - intros; simpl. - unfold lift. - match_option; [|tauto]. - case_eq (two_matrix_env_iter_alt - (fun (x0 : DefinedFunction Ann DTfloat) (g : R) (genv : df_env) => - df_eval_backprop_deriv σ x0 genv g) grad_env2 x grad); [|tauto]. - intros; f_equal. - unfold two_matrix_env_iter_alt in *. - apply (scalarMult_list_env_iter - s c d0 d {n' : nat | n' < n} - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R) - (Some env) (bounded_seq0 m)) - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) (Some env) - (bounded_seq0 m)) - (bounded_seq0 n) grad_env1 grad_env2); trivial. - + intros. - apply (scalarMult_list_env_iter - s c v1 v2 {m' : nat | m' < m} - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R) - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (bounded_seq0 m) env1 env2); trivial. - * intros. - specialize (H i i0 (grad i i0) env0 env3). - assert (vartlookup env0 (s, DTfloat) <> None); [congruence|]. - assert (vartlookup env3 (s, DTfloat) <> None); [congruence|]. - specialize (H H13 H14). - assert (df_eval_backprop_deriv σ (x i i0) env0 (c * grad i i0)%R <> None); [congruence|]. - assert (df_eval_backprop_deriv σ (x i i0) env3 (grad i i0) <> None); [congruence|]. - specialize (H H15 H16). - unfold lift in H; simpl in H. - rewrite H9, H10, H11, H12 in H. - now inversion H. - * intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H9). - * intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H9). - + intros. - apply (vartlookup_list_env_iter - s - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R) - (bounded_seq0 m) env fenv); trivial; intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H7). - + intros. - apply (vartlookup_list_env_iter - s - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (bounded_seq0 m) env fenv); trivial; intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H7). - - Case "Var"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - intros. - destruct (vart_dec v (s, DTfloat)). - + subst; simpl. - rewrite H2; simpl. - rewrite subvar_addvar_scalar_eq; trivial. - rewrite H1; simpl. - now rewrite subvar_addvar_scalar_eq. - + case_eq (vartlookup grad_env1 v); intros; simpl. - * case_eq (vartlookup grad_env2 v); intros; simpl; f_equal. - -- rewrite subvar_addvar_scalar_neq; trivial. - rewrite subvar_addvar_scalar_neq; trivial. - lra. - -- rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl. - rewrite H1. - lra. - * case_eq (vartlookup grad_env2 v); intros; simpl; f_equal. - -- rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl. - rewrite H2. - lra. - -- unfold subvar; simpl. - rewrite H2; rewrite H1. - lra. - - Case "Plus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad)%R); intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 grad d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (c * grad)%R); [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "Minus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad)%R); intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 (-grad)%R d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (-(c * grad))%R); [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 (-grad)%R); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - replace (c * -grad)%R with (-(c*grad))%R in IHdf2 by lra. - rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "Times"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (d * grad)%R grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (d * (c * grad))%R); intros. - rewrite H1, H2, H5 in H; simpl in H. - rewrite H1, H2 in H0; simpl in H0. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 (d * grad)%R); intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (d0 * grad)%R d3 d4). - replace (c * (d * grad))%R with (d * (c * grad))%R in IHdf1 by lra. - rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d3 (d0 * (c * grad))%R); [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (d0 * grad)%R); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (c * (d0 * grad))%R with (d0 * (c*grad))%R in IHdf2 by lra. - rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "Divide"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (grad / d)%R grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad / d)%R); intros. - rewrite H1 in H; rewrite H2 in H; simpl in H. - rewrite H1 in H0; rewrite H2 in H0; simpl in H0. - rewrite H5 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 (grad / d)%R); intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (- d0 / (d*d) * grad)%R d3 d4). - replace (c * (grad / d))%R with (c * grad / d )%R in IHdf1 by lra. - rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d3 (- d0 /(d * d) * (c * grad))%R); [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (- d0 / (d * d) * grad)%R); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (c * (-d0 / (d * d) * grad))%R with (- d0/(d * d) * (c*grad))%R in IHdf2 by lra. - rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "Square"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df); [ | tauto]; intros. - specialize (IHdf (2 * d1 * grad)%R grad_env1 grad_env2); simpl in *. - rewrite H1 in IHdf; rewrite H2 in IHdf. - replace (2 * d1 * (c * grad))%R with (c * (2 * d1 * grad))%R by lra. - rewrite H3 in H; rewrite H3 in H0; simpl in *. - replace (2 * d1 * (c * grad))%R with (c * (2 * d1 * grad))%R in H by lra. - now apply IHdf. - - Case "Exp"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df); [ | tauto]; intros. - specialize (IHdf (grad * exp d1)%R grad_env1 grad_env2); simpl in *. - replace (c * grad * exp d1)%R with (c * (grad * exp d1))%R by lra. - rewrite H1 in IHdf; rewrite H2 in IHdf. - rewrite H3 in H; rewrite H3 in H0. - replace (c * grad * exp d1)%R with (c * (grad * exp d1))%R in H by lra. - now apply IHdf. - - Case "Log"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df); [ | tauto]; intros. - specialize (IHdf (grad / d1)%R grad_env1 grad_env2 ); simpl in *. - replace (c * grad / d1)%R with (c * (grad / d1))%R by lra. - rewrite H1 in IHdf; rewrite H2 in IHdf. - rewrite H3 in H; rewrite H3 in H0. - replace (c * grad / d1)%R with (c * (grad / d1))%R in H by lra. - now apply IHdf. - - Case "Abs"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df); [ | tauto]; intros. - specialize (IHdf (grad * sign d1)%R grad_env1 grad_env2); simpl in *. - replace (c * grad * (@sign floatish_R d1))%R - with (c * (grad * (@sign floatish_R d1)))%R by lra. - rewrite H1 in IHdf; rewrite H2 in IHdf. - rewrite H3 in H; rewrite H3 in H0. - replace (c * grad * (@sign floatish_R d1))%R - with (c * (grad * (@sign floatish_R d1)))%R in H by lra. - now apply IHdf. - - Case "Sign"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (0)%R grad_env1 grad_env2); simpl in *. - replace (0%R) with (c * 0)%R at 1 by lra. - rewrite H1 in IHdf; rewrite H2 in IHdf. - replace (0%R) with (c * 0)%R in H by lra. - now apply IHdf. - - Case "PSign"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (0)%R grad_env1 grad_env2); simpl in *. - replace (0%R) with (c * 0)%R at 1 by lra. - rewrite H1 in IHdf; rewrite H2 in IHdf. - replace (0%R) with (c * 0)%R in H by lra. - now apply IHdf. - - Case "Max"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df1); [ | tauto]; intros. - case_eq (df_eval σ df2); [ | tauto]; intros. - rewrite H3 in H; rewrite H3 in H0. - rewrite H4 in H; rewrite H4 in H0. - case_eq (Rle_dec d1 d2); intros. - + specialize (IHdf2 grad grad_env1 grad_env2); simpl in *. - rewrite H1 in IHdf2; rewrite H2 in IHdf2. - rewrite H5 in H; rewrite H5 in H0. - now apply IHdf2. - + specialize (IHdf1 grad grad_env1 grad_env2); simpl in *. - rewrite H1 in IHdf1; rewrite H2 in IHdf1. - rewrite H5 in H; rewrite H5 in H0. - now apply IHdf1. - - Case "VectorDot"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df1); [ | tauto]; intros. - case_eq (df_eval σ df2); [ | tauto]; intros. - specialize (IHdf1 (vmap (fun rv => (rv * grad)%R) d2) grad_env1 grad_env2). - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 - (vmap (fun rv : R => (rv * (c * grad))%R) d2)); intros. - rewrite H3 in H; rewrite H4 in H; rewrite H5 in H; simpl in H. - rewrite H3 in H0; rewrite H4 in H0; simpl in H0. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 - (vmap (fun rv : R => (rv * grad)%R) d2)); intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (vmap (fun lv => (lv *grad)%R) d1) d3 d4). - replace (fun i => (c * vmap (fun rv : R => rv * grad) d2 i)%R) with - (vmap (fun rv : R => (rv * (c * grad))%R) d2) in IHdf1. - rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d3 - (vmap (fun lv : R => (lv * (c * grad))%R) d1)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (vmap (fun lv : R => (lv * grad)%R) d1)) - ; [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d d5) by trivial. - rewrite (split_subvar d4 d8 d0 d6) by trivial. - replace - (fun i : {n' : nat | n' < n} => (c * vmap (fun lv : R => lv * grad) d1 i)%R) with - (vmap (fun lv : R => (lv * (c * grad))%R) d1) in IHdf2. - rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d) = Some (c * subvar (s, DTfloat) d4 d0)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_mult. - assert ((fun lv => (lv * (c * grad))%R) = (fun x0 => (c * (x0 * grad))%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - now rewrite H13. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_mult. - assert ((fun rv => (rv * (c * grad))%R) = (fun x0 => (c * (x0 * grad))%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - now rewrite H7. - + now rewrite H6 in H0. - + rewrite H3 in H; rewrite H4 in H. - now rewrite H5 in H. - - Case "VectorSum"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (ConstVector n grad) grad_env1 grad_env2). - rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in IHdf. - replace (ConstVector n (c * grad)%R) with - (fun i => (c * ConstVector n grad i)%R). - now apply IHdf. - now unfold ConstVector. - - Case "MatrixSum"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (ConstMatrix m n grad) grad_env1 grad_env2). - rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in IHdf. - replace (ConstMatrix m n (c * grad)%R) with - (fun i j => (c * ConstMatrix m n grad i j)%R). - now apply IHdf. - now unfold ConstMatrix. - - Case "VectorElem"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (fun k => - if equiv_dec (` k) (` i) then grad else 0%R) grad_env1 grad_env2). - rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in *. - replace (fun i0 => (c * (if equiv_dec (` i0) (` i) then grad else 0))%R) with - (fun k : {n' : nat | n' < n} => - if equiv_dec (` k) (` i) then (c * grad)%R else 0%R) in IHdf. - now rewrite IHdf. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)); lra. - - Case "MatrixElem"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) - then if equiv_dec (` k2) (` j) then grad else 0%R else 0%R) - grad_env1 grad_env2). - rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in *. - replace (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) - then if equiv_dec (` k2) (` j) then (c * grad)%R else 0%R else 0%R) with - (fun (i0 : {n' : nat | n' < m}) (j0 : {m' : nat | m' < n}) => - (c * - (if equiv_dec (` i0) (` i) - then if equiv_dec (` j0) (` j) then grad else 0 - else 0))%R) in *. - + now rewrite IHdf. - + apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)); [|lra]. - destruct (equiv_dec (` x0) (` j)); lra. - - Case "MatrixVectorMult"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (fun i j => (grad i * d j)%R) grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 - (fun i j => (c * grad i * d j)%R)); intros. - rewrite H1 in H; rewrite H2 in H; simpl in H. - rewrite H1 in H0; rewrite H2 in H0; simpl in H0. - rewrite H5 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 - (fun i j => (grad i * d j)%R)); intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (matrix_vector_mult (fun i j => d0 j i) grad) d3 d4). - replace - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (c * (grad i * d j))%R) with - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (c * grad i * d j)%R) in IHdf1. - * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - unfold lift; match_case; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 - (matrix_vector_mult (fun i j => d0 j i) grad)) - ; [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (fun i : {n' : nat | n' < n} => - (c * - (@matrix_vector_mult floatish_R _ _ - (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d0 j i0) grad) i)%R) with - (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => d0 j i) - (fun i : {n' : nat | n' < m} => (c * grad i)%R)) in IHdf2. - -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = - Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = - Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - -- unfold matrix_vector_mult. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vsum_mult; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "MatrixVectorAdd"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R)) - ; intros. - + rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - * rewrite H4 in H0; simpl in H0. - match_option_in H0; [|tauto]. - match_option_in H; [|tauto]. - rewrite H3 in IHdf1. - unfold lift. - f_equal. - rewrite H4 in IHdf1. - unfold lift in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite (split_subvar d1 d4 d0 d5) by trivial. - rewrite (split_subvar d2 d3 d d6) by trivial. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) - by (apply IHdf1; trivial; discriminate). - inversion H9; rewrite H11. - assert (Some (subvar (s, DTfloat) d4 d5) = Some (c * subvar (s, DTfloat) d3 d6)%R). - -- f_equal. - apply (scalarMult_list_env_iter - s c d5 d6 {m' : nat | m' < n} - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ df2 env - (transpose - (fun (i0 : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => - (c * grad i0 j)%R) i)) - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) - (bounded_seq0 n) d1 d2 d4 d3); trivial. - ++ intros. - assert (vartlookup env1 (s, DTfloat) <> None); [congruence|]. - assert (vartlookup env2 (s, DTfloat) <> None); [congruence|]. - assert (df_eval_backprop_deriv - σ df2 env1 - (transpose - (fun (i0 : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (c * grad i0 j)%R) i) <> None) - ; [congruence|]. - assert (df_eval_backprop_deriv σ df2 env2 (transpose grad i) <> None) - ;[congruence|]. - specialize (IHdf2 (transpose grad i) env1 env2). - specialize (IHdf2 H15 H16). - specialize (IHdf2 H17 H18). - unfold lift in IHdf2; simpl in IHdf2. - rewrite H10, H12, H14 in IHdf2. - unfold transpose in IHdf2; unfold transpose in H13. - rewrite H13 in IHdf2; now inversion IHdf2. - ++ intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H10). - ++ intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H10). - -- inversion H10; rewrite H13; lra. - * rewrite H4 in H0; tauto. - + rewrite H3 in H; tauto. - - Case "MatrixMult"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (matrix_mult grad (fun i j => d j i)) grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 - (matrix_mult (fun i j => (c * grad i j)%R) - (fun i j => d j i))); intros. - rewrite H1 in H; rewrite H2 in H; simpl in H. - rewrite H1 in H0; rewrite H2 in H0; simpl in H0. - rewrite H5 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 - (matrix_mult grad (fun i j => d j i))); intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (matrix_mult (fun i j => d0 j i) grad) d3 d4). - replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < p}) => - (c * - (@matrix_mult floatish_R m n p grad - (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) => - d j0 i0)) i j)%R) with - (matrix_mult (fun i j => (c * grad i j)%R) - (fun i j => d j i)) in IHdf1. - * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - unfold lift; match_case; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (matrix_mult (fun i j => d0 j i) grad)) - ; [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < n}) => - (c * - (@matrix_mult floatish_R p m n - (fun (i0 : {n' : nat | (n' < p)%nat}) - (j0 : {m' : nat | (m' < m)%nat}) => - d0 j0 i0) grad) i j)%R) with - (matrix_mult (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < m}) => d0 j i) - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (c * grad i j)%R)) in IHdf2. - -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = - Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = - Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - -- unfold matrix_mult. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vsum_mult; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - * unfold matrix_mult. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vsum_mult; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "VectorPlus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i => (c * grad i)%R)); intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 grad d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (fun i => (c * grad i)%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "VectorMinus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i => (c * grad i )%R)) - ; intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 (fun i => (- grad i)%R) d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (fun i => (- (c * grad i))%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 (fun i => (- grad i)%R)); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - replace (fun i => (c * - grad i)%R) with (fun i => (-( c * grad i))%R) in IHdf2. - * rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "MatrixPlus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R)) - ; intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 grad d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (fun i j => (c * grad i j)%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "MatrixMinus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R)) - ; intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 (fun i j => (- grad i j)%R) d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (fun i j => (- (c * grad i j))%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 (fun i j => (- grad i j)%R)); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - replace (fun i j => (c * - grad i j)%R) with - (fun i j => (-( c * grad i j))%R) in IHdf2. - * rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "VectorScalMult"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (vsum (fun j => (d j * grad j)%R)) grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 - (vsum (fun j => (d j * (c * grad j))%R))) - - ; intros. - rewrite H1 in H; rewrite H2 in H; simpl in H. - rewrite H1 in H0; rewrite H2 in H0; simpl in H0. - rewrite H5 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 - (vsum (fun j => (d j * grad j)%R))) - ; intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (fun j => (grad j * d0)%R) d3 d4). - replace - (c * - (@vsum floatish_R _ - (fun (j : {n' : nat | (n' < n)%nat}) => - d j * grad j)))%R with - (vsum - (fun (j : {n' : nat | n' < n}) => - (d j * (c * grad j))%R)) in IHdf1. - * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d3 - (fun (j : {n' : nat | n' < n}) => - (d0 * (c * grad j))%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (fun j => (d0 * grad j)%R)) - ; [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (fun i => (c * (grad i * d0))%R) with - (fun j => (d0 * (c * grad j))%R) in IHdf2. - replace (fun j => (grad j * d0)%R) with (fun j => (d0 * grad j)%R) in IHdf2. - -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = - Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = - Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - -- apply FunctionalExtensionality.functional_extensionality; intros. - lra. - -- apply FunctionalExtensionality.functional_extensionality; intros. - lra. - * rewrite vsum_mult; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "MatrixScalMult"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (msum (fun i j => (d i j * grad i j)%R)) grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 - (msum (fun i j => (d i j * (c * grad i j))%R))) - - ; intros. - rewrite H1, H2, H5 in H; simpl in H. - rewrite H1, H2 in H0; simpl in H0. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 - (msum (fun i j => (d i j * grad i j)%R))) - ; intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (fun i j => (grad i j * d0)%R) d3 d4). - replace - (c * - (@msum floatish_R _ _ - (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d i j * grad i j)))%R with - - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d i j * (c * grad i j))%R)) in IHdf1. - * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d3 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (c * grad i j * d0)%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (fun i j => (grad i j * d0)%R)) - ; [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (fun i j => (c * (grad i j * d0))%R) with - (fun i j => (c * grad i j * d0)%R) in IHdf2. - -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = - Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = - Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - -- apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - * unfold msum. - rewrite vsum_mult; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite vsum_mult. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "VectorApply"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - simpl in *. - case_eq (df_eval σ df2); [ | tauto]. - intros. - rewrite H3 in H; rewrite H3 in H0; simpl in *. - match_option_in H0; [|tauto]. - specialize (IHdf1 v0 grad_env1 grad_env2). - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1. - match_option_in H; [|tauto]. - vectoro_assert_forall_in eqq i. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in eqq0 i. - apply vectoro_to_ovector_forall_some_f; trivial. - assert (v1 = (fun i => (c * v0 i)%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H4 x). - specialize (H5 x). - rewrite vmap_nth in H4; simpl in H4. - rewrite vmap_nth in H5; simpl in H5. - match_option_in H4. - match_option_in H5. - inversion H4; inversion H5; subst. - assert (Some d2 = Some d3). - rewrite <- eqq1. - rewrite <- eqq2; trivial. - inversion H6; subst; lra. - subst. - apply IHdf1; trivial; discriminate. - - Case "MatrixApply"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - simpl in *. - case_eq (df_eval σ df2); [ | tauto]. - intros. - rewrite H3 in H; rewrite H3 in H0; simpl in *. - match_option_in H0; [|tauto]. - match_option_in H; [|tauto]. - unfold matrixo_to_omatrix in eqq. - unfold matrixo_to_omatrix in eqq0. - vectoro_assert_forall_in eqq i. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in eqq0 i. - apply vectoro_to_ovector_forall_some_f; trivial. - assert (m1 = (fun i j => (c * m0 i j)%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H4 x); specialize (H5 x); simpl in H4; simpl in H5. - vectoro_assert_forall_in H4 j. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in H5 j. - apply vectoro_to_ovector_forall_some_f; trivial. - specialize (H6 x0); specialize (H7 x0). - unfold mmap in H6; unfold mmap in H7. - rewrite vmap_nth in H6; rewrite vmap_nth in H6; simpl in H6. - rewrite vmap_nth in H7; rewrite vmap_nth in H7; simpl in H7. - match_case_in H6; intros. - rewrite H8 in H6; simpl in H6. - match_case_in H7; intros. - rewrite H9 in H7; simpl in H7. - match_option_in H6. - match_option_in H7. - inversion H6; inversion H7. - unfold matrix_zip in H8. - unfold matrix_zip in H9. - rewrite vmap_nth in H8. - rewrite vmap_nth in H9. - unfold vector_zip in H8. - unfold vector_zip in H9. - inversion H8; subst r r0. - inversion H9; subst r1 r2. - assert (Some d2 = Some d3). - rewrite <- eqq1; rewrite <- eqq2; trivial. - inversion H10; subst; lra. - specialize (IHdf1 m0 grad_env1 grad_env2). - rewrite H1, H2 in IHdf1; simpl in IHdf1. - subst m1. - apply IHdf1; trivial; discriminate. - - Case "VLossfun"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - simpl in *. - case_eq (df_eval σ df2); [ | tauto]. - intros. - rewrite H3 in H; rewrite H3 in H0; simpl in *. - match_option_in H0; [|tauto]. - match_option_in H; [|tauto]. - vectoro_assert_forall_in eqq i. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in eqq0 i. - apply vectoro_to_ovector_forall_some_f; trivial. - assert (v0 = (fun i => (c * v i)%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H4 x); specialize (H5 x). - rewrite vmap_nth in H4; simpl in H4. - rewrite vmap_nth in H5; simpl in H5. - match_option_in H4. - match_option_in H5. - assert (Some d2 = Some d3). - rewrite <- eqq1; rewrite <- eqq2; trivial. - inversion H6; subst. - inversion H4; inversion H5; lra. - subst. - specialize (IHdf1 v grad_env1 grad_env2). - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1. - apply IHdf1; trivial; discriminate. - - Case "MLossfun"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - simpl in *. - case_eq (df_eval σ df2); [ | tauto]. - intros. - rewrite H3 in H; rewrite H3 in H0; simpl in *. - match_option_in H0; [|tauto]. - match_option_in H; [|tauto]. - unfold matrixo_to_omatrix in eqq. - unfold matrixo_to_omatrix in eqq0. - vectoro_assert_forall_in eqq i. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in eqq0 i. - apply vectoro_to_ovector_forall_some_f; trivial. - assert (m1 = (fun i j => (c * m0 i j)%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H4 x); specialize (H5 x); simpl in H4; simpl in H5. - vectoro_assert_forall_in H4 j. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in H5 j. - apply vectoro_to_ovector_forall_some_f; trivial. - specialize (H6 x0); specialize (H7 x0). - unfold mmap in H6; unfold mmap in H7. - rewrite vmap_nth in H6; rewrite vmap_nth in H6; simpl in H6. - rewrite vmap_nth in H7; rewrite vmap_nth in H7; simpl in H7. - match_destr_in H6. - match_option_in H6. - match_option_in H7. - assert (Some d2 = Some d3). - rewrite <- eqq1; rewrite <- eqq2; trivial. - inversion H8; subst. - inversion H6; inversion H7. - lra. - rewrite H6. - specialize (IHdf1 m0 grad_env1 grad_env2). - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1. - subst. - apply IHdf1; trivial; discriminate. - Qed. - - Ltac simpl_closed_backprop := - match goal with - | [|- context [ - match df_eval_backprop_deriv ?σ ?df1 ?grad_env1 ?grad with - | Some _ => _ - | None => _ - end]] => case_eq (df_eval_backprop_deriv σ df1 grad_env1 grad) - ; [let env := fresh "env" in let eqq := fresh "eqq" in intros env eqq | - let eqq := fresh "eqq" in - intros eqq; - eelim backprop_deriv_fully_closed_not_none; [clear eqq | eapply eqq]; trivial - ] - end. - - Ltac simpler2 := - trivial; - repeat - match goal with - | [ |- Some _ <> None ] => congruence - | [ |- None <> Some _ ] => congruence - - | [H:vartlookup ?grad_env ?a <> None - |- context [vartlookup ?grad_env ?a]] => - case_eq (vartlookup grad_env a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d - | intros ?; eelim H; solve[auto]] - | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2 - |- context [vartlookup ?grad_env2 ?a]] => - case_eq (vartlookup grad_env2 a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d - | let eqq := fresh "eqq" in - intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto] - ] - - | [H:vartlookup ?grad_env ?a <> None - |- context [vartlookup ?grad_env ?a]] => - case_eq (vartlookup grad_env a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d - | intros ?; eelim H; solve[auto || congruence]] - | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2 - |- context [vartlookup ?grad_env2 ?a]] => - case_eq (vartlookup grad_env2 a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d - | let eqq := fresh "eqq" in - intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto || congruence] - ] - | [H:vartlookup ?grad_env ?a <> None, - H2:context [match vartlookup ?grad_env ?a with | _ => _ end] |- _] => - case_eq (vartlookup grad_env a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d; rewrite eqq in H2 - | intros ?; eelim H; solve[auto || congruence]] - | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2, - H2: context [match vartlookup ?grad_env2 ?a with _ => _ end] |- _] => - case_eq (vartlookup grad_env2 a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d; rewrite eqq in H2 - | let eqq := fresh "eqq" in - intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto || congruence]] - - end. - - Lemma backprop_indep_env {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s:SubVar) - (grad_env1 grad_env2:df_env) (grad : definition_function_types_interp T) : - let v := (s, DTfloat) in - vartlookup grad_env1 v <> None -> - vartlookup grad_env2 v <> None -> - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - df_eval_backprop_delta σ df v grad_env1 grad = - df_eval_backprop_delta σ df v grad_env2 grad. - Proof. - revert grad_env1 grad_env2. - unfold df_eval_backprop_delta. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case - ; simpl; intros grad_env1 grad_env2 neq1 neq2; intros. - - Case "Number"%string. - match_option; [|tauto]. - match_option; [|tauto]. - unfold snd in *. - f_equal. - unfold subvar; simpl. - rewrite eqq, eqq0. - lra. - - Case "Constant"%string. - match_option; [|tauto]. - match_option; [|tauto]. - unfold snd in *. - f_equal. - unfold subvar; simpl. - rewrite eqq, eqq0. - lra. - - Case "DVector"%string. - rewrite vforall_forall in H0. - unfold two_vector_env_iter_alt. - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift; simpl. - match_option. - + match_option. - f_equal. - apply (list_env_iter_subvar_env2 - s d d0 {n' : nat | n' < n} - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad i)) - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad i)) - (bounded_seq0 n) grad_env1 grad_env2 d1 d2); trivial. - * intros. - specialize (H i (grad i) env1 env2). - cut_to H; try congruence; eauto 3. - rewrite H1, H2, H3, H4 in H. - unfold lift in H. - now inversion H. - * intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H1); trivial. - * intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H1); trivial. - * assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad i)) (Some grad_env2) - (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - + assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad i)) (Some grad_env1) - (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - - Case "DMatrix"%string. - rewrite vforall_forall in H0. - unfold two_matrix_env_iter_alt. - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift; simpl. - match_option. - + match_option. - * f_equal. - apply (list_env_iter_subvar_env2 - s d d0 {n' : nat | n' < n} - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (Some env) (bounded_seq0 m)) - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (Some env) (bounded_seq0 m)) - (bounded_seq0 n) grad_env1 grad_env2 d1 d2); trivial. - -- intros. - apply (list_env_iter_subvar_env2 - s v1 v2 {m' : nat | m' < m} - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (bounded_seq0 m) env1 env2 fenv1 fenv2); trivial. - ++ intros. - specialize (H i i0 (grad i i0) env0 env3). - cut_to H. - rewrite H5, H6, H7, H8 in H. - unfold lift in H. - now inversion H. - congruence. - congruence. - specialize (H0 i). - rewrite vforall_forall in H0. - apply (H0 i0). - ++ intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5); trivial. - ++ intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5); trivial. - -- intros. - apply (vartlookup_list_env_iter - s - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (bounded_seq0 m) env fenv); trivial. - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3); trivial. - -- intros. - apply (vartlookup_list_env_iter - s - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (bounded_seq0 m) env fenv); trivial. - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3); trivial. - * assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (Some env) (bounded_seq0 m)) (Some grad_env2) (bounded_seq0 n) - <> None). - apply list_env_iter_total_fun; intros. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - specialize (H0 a). - rewrite vforall_forall in H0. - apply (H0 a0). - tauto. - + assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (Some env) (bounded_seq0 m)) (Some grad_env1) (bounded_seq0 n) - <> None). - apply list_env_iter_total_fun; intros. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - specialize (H0 a). - rewrite vforall_forall in H0. - apply (H0 a0). - tauto. - - Case "Var"%string. - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - unfold lift, subvar; simpl. - destruct (v == (s,DTfloat)). - + invcs e. - unfold addvar; simpl. - rewrite eqq, H0. - rewrite lookup_update. - rewrite lookup_update. - f_equal; lra. - + assert (v<> (s, DTfloat)) by congruence. - case_eq (vartlookup grad_env1 v); intros. - * rewrite lookup_update_neq; trivial. - rewrite eqq. - case_eq (vartlookup grad_env2 v); intros. - -- rewrite lookup_update_neq; trivial. - rewrite H0; f_equal; lra. - -- rewrite H0; f_equal; lra. - * rewrite eqq. - case_eq (vartlookup grad_env2 v); intros. - -- rewrite lookup_update_neq; trivial. - rewrite H0; f_equal; lra. - -- rewrite H0; f_equal; lra. - - Case "Plus"%string. - destruct H. - unfold lift. - repeat simpl_closed_backprop. - simpler2. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1,eqq3,eqq4 in IHdf1. - specialize (IHdf2 grad env env1). - cut_to IHdf2; simpler2. - unfold lift in IHdf1; inversion IHdf1. - rewrite eqq2, eqq0 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar env env0 val val1); trivial. - rewrite (split_subvar env1 env2 val0 val2); trivial. - rewrite H2, H3; lra. - - Case "Minus"%string. - destruct H. - unfold lift. - repeat simpl_closed_backprop. - simpler2. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1,eqq3,eqq4 in IHdf1. - specialize (IHdf2 (-grad)%R env env1). - cut_to IHdf2; simpler2. - unfold lift in IHdf1; invcs IHdf1. - rewrite eqq0,eqq2 in IHdf2; unfold lift in IHdf2; invcs IHdf2. - rewrite (split_subvar env env0 val val1); trivial. - rewrite (split_subvar env1 env2 val0 val2); trivial. - rewrite H2, H3; lra. - - Case "Times"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - unfold lift. - repeat simpl_closed_backprop. - simpler2. - f_equal. - specialize (IHdf1 (d1 * grad)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq6,eqq2,eqq4 in IHdf1. - specialize (IHdf2 (d0 * grad)%R env env1). - cut_to IHdf2; simpler2. - unfold lift in IHdf1; invcs IHdf1. - rewrite eqq5, eqq3 in IHdf2; unfold lift in IHdf2; invcs IHdf2. - rewrite (split_subvar env env0 d val0); trivial. - rewrite (split_subvar env1 env2 val val1); trivial. - rewrite H4, H5; lra. - - Case "Divide"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - unfold lift. - repeat simpl_closed_backprop. - simpler2. - f_equal. - specialize (IHdf1 (grad/d1)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq6,eqq2,eqq4 in IHdf1. - specialize (IHdf2 (-d0/(d1*d1) * grad)%R env env1). - cut_to IHdf2; simpler2. - unfold lift in IHdf1; inversion IHdf1. - rewrite eqq5, eqq3 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar env env0 d val0); trivial. - rewrite (split_subvar env1 env2 val val1); trivial. - rewrite H4, H5; lra. - - Case "Square"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift. - repeat simpl_closed_backprop. - f_equal. - specialize (IHdf (2 * d0 * grad)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1,eqq2,eqq3 in IHdf. - now unfold lift in IHdf; inversion IHdf. - - Case "Exp"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift. - repeat simpl_closed_backprop. - f_equal. - specialize (IHdf (grad * exp d0)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1,eqq2, eqq3 in IHdf. - now unfold lift in IHdf; inversion IHdf. - - Case "Log"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env1 (grad/d0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - unfold lift. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env2 (grad/d0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf (grad / d0)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf. - now rewrite eqq2, eqq3 in IHdf; unfold lift in IHdf; inversion IHdf. - - Case "Abs"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env1 (grad * sign d0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - unfold lift. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env2 (grad * sign d0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf (grad * sign d0)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf. - now rewrite eqq2, eqq3 in IHdf; unfold lift in IHdf; inversion IHdf. - - Case "Sign"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env1 (0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - unfold lift. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env2 (0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf (0)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq0 in IHdf. - now rewrite eqq1, eqq2 in IHdf; unfold lift in IHdf; inversion IHdf. - - Case "PSign"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env1 (0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - unfold lift. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env2 (0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf (0)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq0 in IHdf. - now rewrite eqq1, eqq2 in IHdf; unfold lift in IHdf; inversion IHdf. - - Case "Max"%string. - destruct H. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - specialize (IHdf2 grad grad_env1 grad_env2 neq1 neq2 H0). - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - match_option; [|tauto]. - rewrite eqq, eqq2 in IHdf1. - rewrite eqq, eqq2 in IHdf2. - destruct (Rle_dec d0 d1); trivial. - - Case "VectorDot"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 (vmap (fun rv : R => (rv * grad)%R) d1) - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H3 in IHdf1. - assert (df_eval_backprop_deriv σ df1 grad_env1 - (vmap (fun rv : R => (rv * grad)%R) d1) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (vmap (fun rv : R => (rv * grad)%R) d1) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1). - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2). - specialize (IHdf2 (vmap (fun lv : R => (lv * grad)%R) d0) - d3 d4 H6 H7 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv σ df2 d3 - (vmap (fun lv : R => (lv * grad)%R) d0) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (vmap (fun lv : R => (lv * grad)%R) d0) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - unfold lift in IHdf1. - rewrite eqq2, eqq3 in IHdf1; inversion IHdf1. - rewrite eqq6, eqq7 in IHdf2; inversion IHdf2. - rewrite (split_subvar d3 d7 d d5); trivial. - rewrite (split_subvar d4 d8 d2 d6); trivial. - now rewrite H11, H12. - - Case "VectorSum"%string. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (ConstVector n grad) grad_env1 grad_env2 neq1 neq2 H). - now rewrite eqq, eqq0 in IHdf. - - Case "MatrixSum"%string. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (ConstMatrix m n grad) grad_env1 grad_env2 neq1 neq2 H). - now rewrite eqq, eqq0 in IHdf. - - Case "VectorElem"%string. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i) - then grad else 0%R) - grad_env1 grad_env2 neq1 neq2 H). - now rewrite eqq, eqq0 in IHdf. - - Case "MatrixElem"%string. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) then - if equiv_dec (` k2) (` j) then grad else 0%R else 0%R) - grad_env1 grad_env2 neq1 neq2 H). - now rewrite eqq, eqq0 in IHdf. - - Case "MatrixVectorMult"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 (fun (i : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (grad i * d1 j)%R) - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H3 in IHdf1. - assert (df_eval_backprop_deriv σ df1 grad_env1 - (fun (i : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (grad i * d1 j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (fun (i : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (grad i * d1 j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1). - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2). - specialize (IHdf2 - (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d0 j i) grad) - d3 d4 H6 H7 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv - σ df2 d3 - (matrix_vector_mult (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d0 j i) - grad) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (matrix_vector_mult (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d0 j i) - grad) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - unfold lift in IHdf1. - unfold lift in IHdf2. - rewrite eqq2, eqq3 in IHdf1; inversion IHdf1. - rewrite eqq6, eqq7 in IHdf2; inversion IHdf2. - rewrite (split_subvar d3 d7 d d5); trivial. - rewrite (split_subvar d4 d8 d2 d6); trivial. - now rewrite H11, H12. - - Case "MatrixVectorAdd"%string. - destruct H. - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 grad - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H1 in IHdf1. - assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); [intros | tauto]. - rewrite eqq0, H4 in IHdf1. - unfold lift in IHdf1. - inversion IHdf1. - assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d1) - (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d2) - (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - unfold lift; f_equal. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1). - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - assert (vartlookup d3 (s, DTfloat) <> None). - apply - (vartlookup_list_env_iter - s - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) - (bounded_seq0 n) d1); trivial. - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H10 (s, DTfloat) H11). - assert (vartlookup d4 (s, DTfloat) <> None). - apply - (vartlookup_list_env_iter - s - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) - (bounded_seq0 n) d2); trivial. - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H11 (s, DTfloat) H12). - unfold subvar in IHdf1; simpl in IHdf1. - match_option_in IHdf1; [|tauto]. - match_option_in IHdf1; [|tauto]. - rewrite (split_subvar d1 d3 d d5); trivial. - rewrite (split_subvar d2 d4 d0 d6); trivial. - rewrite H6. - apply Rplus_eq_compat_r . - apply (list_env_iter_subvar_env2 - s d5 d6 {m' : nat | m' < n} - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) - (bounded_seq0 n) d1 d2 d3 d4); trivial. - intros. - specialize (IHdf2 (transpose grad i) env1 env2). - rewrite H12, H13 in IHdf2. - cut_to IHdf2; trivial. - rewrite H14, H15 in IHdf2. - unfold lift in IHdf2. - now inversion IHdf2. - discriminate. - discriminate. - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H12 (s, DTfloat) H13). - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H12 (s, DTfloat) H13). - - Case "MatrixMult"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 - (matrix_mult grad (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < p}) => d1 j i)) - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H3 in IHdf1. - assert (df_eval_backprop_deriv σ df1 grad_env1 - (matrix_mult grad (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < p}) => d1 j i)) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (matrix_mult grad (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < p}) => d1 j i)) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - rewrite eqq2, eqq3 in IHdf1. - unfold lift in IHdf1; inversion IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1). - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2). - specialize (IHdf2 - (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d0 j i) grad) - d3 d4 H6 H8 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv - σ df2 d3 - (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d0 j i) grad) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv - σ df2 d4 - (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d0 j i) grad) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - unfold lift in IHdf1. - unfold lift in IHdf2. - rewrite eqq6, eqq7 in IHdf2; inversion IHdf2. - rewrite (split_subvar d3 d7 d d5); trivial. - rewrite (split_subvar d4 d8 d2 d6); trivial. - now rewrite H7, H12. - - Case "VectorPlus"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv σ df2 d0 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf1. - specialize (IHdf2 grad d0 d2). - assert (vartlookup d0 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1). - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2). - specialize (IHdf2 H5 H6 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1. - rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar d0 d3 d d5); trivial. - rewrite (split_subvar d2 d4 d1 d6); trivial. - lra. - - Case "VectorMinus"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv - σ df2 d0 - (fun i : {n' : nat | n' < n} => (- grad i)%R) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d2 - (fun i : {n' : nat | n' < n} => (- grad i)%R) <> None) - by (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf1. - specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad i)%R) d0 d2). - assert (vartlookup d0 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1). - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2). - specialize (IHdf2 H5 H6 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1. - rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar d0 d3 d d5); trivial. - rewrite (split_subvar d2 d4 d1 d6); trivial. - lra. - - Case "MatrixPlus"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv σ df2 d0 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf1. - specialize (IHdf2 grad d0 d2). - assert (vartlookup d0 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1). - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2). - specialize (IHdf2 H5 H6 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1. - rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar d0 d3 d d5); trivial. - rewrite (split_subvar d2 d4 d1 d6); trivial. - lra. - - Case "MatrixMinus"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv - σ df2 d0 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf1. - specialize (IHdf2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R) - d0 d2). - assert (vartlookup d0 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1). - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2). - specialize (IHdf2 H5 H6 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1. - rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar d0 d3 d d5); trivial. - rewrite (split_subvar d2 d4 d1 d6); trivial. - lra. - - Case "VectorScalMult"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R)) - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H3 in IHdf1. - assert (df_eval_backprop_deriv - σ df1 grad_env1 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R)) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R)) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1). - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2). - specialize (IHdf2 (fun j : {n' : nat | n' < n} => (d0 * grad j)%R) - d3 d4 H6 H7 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv σ df2 d3 - (fun j : {n' : nat | n' < n} => (d0 * grad j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (fun j : {n' : nat | n' < n} => (d0 * grad j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - unfold lift in IHdf1. - rewrite eqq2, eqq3 in IHdf1; inversion IHdf1. - rewrite eqq6, eqq7 in IHdf2; inversion IHdf2. - rewrite (split_subvar d3 d7 d d5); trivial. - rewrite (split_subvar d4 d8 d2 d6); trivial. - now rewrite H11, H12. - - Case "MatrixScalMult"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 - (msum - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R)) - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H3 in IHdf1. - assert (df_eval_backprop_deriv - σ df1 grad_env1 - (msum - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R)) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (msum - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R)) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1). - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2). - specialize (IHdf2 (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad i j * d0)%R) - d3 d4 H6 H7 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv σ df2 d3 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad i j * d0)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad i j * d0)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - unfold lift in IHdf1. - rewrite eqq2, eqq3 in IHdf1; inversion IHdf1. - rewrite eqq6, eqq7 in IHdf2; inversion IHdf2. - rewrite (split_subvar d3 d7 d d5); trivial. - rewrite (split_subvar d4 d8 d2 d6); trivial. - now rewrite H11, H12. - - Case "VectorApply"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - match_option. - specialize (IHdf1 v0 grad_env1 grad_env2 neq1 neq2 H0). - now rewrite eqq, H2 in IHdf1. - - Case "MatrixApply"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - match_option. - specialize (IHdf1 m0 grad_env1 grad_env2 neq1 neq2 H0). - now rewrite eqq, H2 in IHdf1. - - Case "VLossfun"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - match_option. - specialize (IHdf1 v grad_env1 grad_env2 neq1 neq2 H0). - now rewrite eqq, H2 in IHdf1. - - Case "MLossfun"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - match_option. - specialize (IHdf1 m0 grad_env1 grad_env2 neq1 neq2 H0). - now rewrite eqq, H2 in IHdf1. - Qed. - - Lemma backprop_exchange_order {T} (σ:df_env) (df1 df2 :DefinedFunction UnitAnn T) (s: SubVar) - (env:df_env) (grad1 grad2 : definition_function_types_interp T) : - let v := (s, DTfloat) in - vartlookup env v <> None -> - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df1 vl -> - fully_closed_over df2 vl -> - match - df_eval_backprop_deriv σ df1 env grad1, df_eval_backprop_deriv σ df2 env grad2, vartlookup env v with - | Some env1, Some env2, Some oval => - lift (fun e => subvar v e oval) - (df_eval_backprop_deriv σ df1 env2 grad1) = - lift (fun e => subvar v e oval) - (df_eval_backprop_deriv σ df2 env1 grad2) - | _, _, _ => True - end. - Proof. - intros. - do 3 match_option. - unfold lift. - do 2 match_option. - - assert (vartlookup d0 v <> None); simpler2. - assert (vartlookup d v <> None); simpler2. - case_eq (vartlookup d0 v); [intros|tauto]. - case_eq (vartlookup d v); [intros|tauto]. - f_equal. subst v. - rewrite (split_subvar d0 d2 d1 d4); trivial. - rewrite (split_subvar d d3 d1 d5); trivial. - generalize (backprop_indep_env σ df1 s env d0 grad1); intros. - generalize (backprop_indep_env σ df2 s env d grad2); intros. - simpl in H6; simpl in H7. - cut_to H6; trivial; try congruence. - cut_to H7; trivial; try congruence. - unfold df_eval_backprop_delta in *. - rewrite eqq1, H4 in H6. - rewrite eqq1, H5 in H7. - unfold lift in H6; simpl in H6. - unfold lift in H7; simpl in H7. - rewrite eqq, eqq2 in H6. - rewrite eqq0, eqq3 in H7. - inversion H6; inversion H7. - rewrite H9, H10; lra. - - assert (df_eval_backprop_deriv σ df2 d grad2 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - - assert (df_eval_backprop_deriv σ df1 d0 grad1 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - Qed. - - Lemma list_env_iter_id {A} (env : df_env) (l : list A) : - list_env_iter (fun (_ : A) (env : df_env) => Some env) - (Some env) l = Some env. - Proof. - now induction l. - Qed. - - Lemma backprop_grad_sum_list_env_iter {m} (σ:df_env) - (vecdf:Vector (DefinedFunction UnitAnn DTfloat) m) (s: SubVar) - (grad_env1 grad_env2 grad_env3:df_env) - (grad1 grad2 : (Vector float m)) - (val1 val2 val3 : float) - (l : list {m' | m' < m} ) - : - let v := (s, DTfloat) in - (forall (i:{m' | m' < m}), - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over (vecdf i) vl) -> - (forall (i:{m' | m' < m}) (env1 env2 env3 : df_env), - vartlookup env1 v <> None -> - vartlookup env2 v <> None -> - vartlookup env3 v <> None -> - - df_eval_backprop_delta σ (vecdf i) v env3 - (grad1 i + grad2 i)%R = - lift2 dfti_gen_plus - (df_eval_backprop_delta σ (vecdf i) v env1 (grad1 i)) - (df_eval_backprop_delta σ (vecdf i) v env2 (grad2 i))) -> - - vartlookup grad_env1 v = Some val1 -> - vartlookup grad_env2 v = Some val2 -> - vartlookup grad_env3 v = Some val3 -> - - lift (fun e : df_env => subvar v e val3) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad1 j + grad2 j)%R) - (Some grad_env3) l) = - lift2 dfti_gen_plus - (lift (fun e : df_env => subvar v e val1) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some grad_env1) l)) - (lift (fun e : df_env => subvar v e val2) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some grad_env2) l)). - Proof. - intros. - revert val1 val2 val3 grad_env1 grad_env2 grad_env3 H1 H2 H3. - induction l. - - intros. - simpl; f_equal. - unfold subvar; simpl. - rewrite H1,H2,H3. - lra. - - intros. - simpl. - unfold df_eval_backprop_delta in H0. - unfold lift, lift2. - assert (df_eval_backprop_deriv σ (vecdf a) grad_env3 (grad1 a + grad2 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env3 (grad1 a + grad2 a)%R) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ (vecdf a) grad_env1 (grad1 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env1 (grad1 a)%R) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ (vecdf a) grad_env2 (grad2 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env2 (grad2 a)%R) - ; [intros | tauto]. - - assert (vartlookup d v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 v); congruence. - assert (vartlookup d0 v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 v); congruence. - assert (vartlookup d1 v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 v); congruence. - - case_eq (vartlookup d v); [intros v3 eq3 |tauto]. - case_eq (vartlookup d0 v); [intros v1 eq1 |tauto]. - case_eq (vartlookup d1 v); [intros v2 eq2 |tauto]. - - specialize (IHl v1 v2 v3 d0 d1 d eq1 eq2 eq3). - unfold lift, lift2. - match_option. - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some d0) l). - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some d1) l). - + intros. - rewrite eqq, H13, H14 in IHl. - unfold lift, lift2 in IHl; simpl in IHl. - f_equal. - subst v. - rewrite (split_subvar d d2 val3 v3); trivial. - rewrite (split_subvar d0 d4 val1 v1); trivial. - rewrite (split_subvar d1 d3 val2 v2); trivial. - inversion IHl. - rewrite H16. - specialize (H0 a grad_env1 grad_env2 grad_env3). - cut_to H0; try congruence. - rewrite H1,H2,H3 in H0. - rewrite H5,H7,H9 in H0. - unfold lift, lift2 in H0. - inversion H0. - rewrite H17; lra. - + intros. - assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some d1) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - + intros. - assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some d0) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - + assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad1 j + grad2 j)%R) (Some d) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - Qed. - - Lemma backprop_mat_grad_sum_list_env_iter {m n} (σ:df_env) - (df : DefinedFunction UnitAnn (DTVector n)) (s: SubVar) - (grad_env1 grad_env2 grad_env3:df_env) - (grad1 grad2 : (Matrix float m n)) - (val1 val2 val3 : float) - (l : list {m' | m' < m} ) - : - let v := (s, DTfloat) in - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - (forall (i:{m' | m' < m}) (env1 env2 env3 : df_env), - vartlookup env1 v <> None -> - vartlookup env2 v <> None -> - vartlookup env3 v <> None -> - - df_eval_backprop_delta σ df v env3 (fun j => (grad1 i j + grad2 i j)%R) = - lift2 dfti_gen_plus - (df_eval_backprop_delta σ df v env1 (grad1 i)) - (df_eval_backprop_delta σ df v env2 (grad2 i))) -> - - vartlookup grad_env1 v = Some val1 -> - vartlookup grad_env2 v = Some val2 -> - vartlookup grad_env3 v = Some val3 -> - - lift (fun e : df_env => subvar v e val3) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (fun k => (grad1 j k + grad2 j k)%R) ) - (Some grad_env3) l) = - lift2 dfti_gen_plus - (lift (fun e : df_env => subvar v e val1) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad1 j)) (Some grad_env1) l)) - (lift (fun e : df_env => subvar v e val2) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad2 j)) (Some grad_env2) l)). - Proof. - intros. - revert val1 val2 val3 grad_env1 grad_env2 grad_env3 H1 H2 H3. - induction l. - - intros. - simpl; f_equal. - unfold subvar; simpl. - rewrite H1,H2,H3. - lra. - - intros. - unfold df_eval_backprop_delta in *. - simpl. - unfold lift, lift2. - assert (df_eval_backprop_deriv σ df grad_env3 - (fun k : {n' : nat | n' < n} => (grad1 a k + grad2 a k)%R) <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ df grad_env3 - (fun k : {n' : nat | n' < n} => (grad1 a k + grad2 a k)%R)) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ df grad_env1 (grad1 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ df grad_env1 (grad1 a)%R) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ df grad_env2 (grad2 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ df grad_env2 (grad2 a)%R) - ; [intros | tauto]. - - assert (vartlookup d v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 v); congruence. - assert (vartlookup d0 v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 v); congruence. - assert (vartlookup d1 v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 v); congruence. - - case_eq (vartlookup d v); [intros v3 eq3 |tauto]. - case_eq (vartlookup d0 v); [intros v1 eq1 |tauto]. - case_eq (vartlookup d1 v); [intros v2 eq2 |tauto]. - - specialize (IHl v1 v2 v3 d0 d1 d eq1 eq2 eq3). - match_option. - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad1 j)) (Some d0) l). - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad2 j)) (Some d1) l). - + intros. - rewrite eqq, H13, H14 in IHl. - unfold lift, lift2 in IHl; simpl in IHl. - f_equal. - subst v. - rewrite (split_subvar d d2 val3 v3); trivial. - rewrite (split_subvar d0 d4 val1 v1); trivial. - rewrite (split_subvar d1 d3 val2 v2); trivial. - inversion IHl. - rewrite H16. - specialize (H0 a grad_env1 grad_env2 grad_env3). - cut_to H0; try congruence. - rewrite H1,H2,H3 in H0. - rewrite H5,H7,H9 in H0. - unfold lift, lift2 in H0. - inversion H0. - rewrite H17; lra. - + intros. - assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad2 j)) (Some d1) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - + intros. - assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad1 j)) (Some d0) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - + assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env - (fun k : {n' : nat | n' < n} => (grad1 j k + grad2 j k)%R)) - (Some d) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - Qed. - - Lemma matrix_zip_m_n {T} {m n} {i j} {m1 m2 : Matrix T m n} : - matrix_zip m1 m2 i j = (m1 i j, m2 i j). - Proof. - unfold matrix_zip. - rewrite vmap_nth; simpl. - now unfold vector_zip. - Qed. - - Lemma backprop_grad_sum {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s: SubVar) - (grad_env1 grad_env2 grad_env3:df_env) - (grad1 grad2 : definition_function_types_interp T) : - let v := (s, DTfloat) in - vartlookup grad_env1 v <> None -> - vartlookup grad_env2 v <> None -> - vartlookup grad_env3 v <> None -> - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - df_eval_backprop_delta σ df (s,DTfloat) grad_env3 (dfti_gen_plus grad1 grad2) = - lift2 dfti_gen_plus - (df_eval_backprop_delta σ df (s,DTfloat) grad_env1 grad1) - (df_eval_backprop_delta σ df (s,DTfloat) grad_env2 grad2). - Proof. - unfold df_eval_backprop_delta. - revert grad_env1 grad_env2 grad_env3. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case -(* - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case -*) - ; simpl; intros. - - Case "Number"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1. - f_equal; lra. - - Case "Constant"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1. - f_equal; lra. - - Case "DVector"%string. - unfold two_vector_env_iter_alt in *. - rewrite vforall_forall in H3. - revert grad_env1 grad_env2 grad_env3 H0 H1 H2. - induction (bounded_seq0 n). - + intros. - simpl. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1. - f_equal; lra. - + intros. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift; simpl in *. - assert (df_eval_backprop_deriv σ (x a) grad_env3 (grad1 a + grad2 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - assert (df_eval_backprop_deriv σ (x a) grad_env1 (grad1 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - assert (df_eval_backprop_deriv σ (x a) grad_env2 (grad2 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ (x a) grad_env3 (grad1 a + grad2 a)%R) - ; [intros|tauto]. - case_eq (df_eval_backprop_deriv σ (x a) grad_env1 (grad1 a)%R) - ; [intros|tauto]. - case_eq (df_eval_backprop_deriv σ (x a) grad_env2 (grad2 a)%R) - ; [intros|tauto]. - assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad1 i + grad2 i)%R) - (Some d2) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply (H3 a0). - match_option; [|tauto]. - assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad1 i)) (Some d3) - l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply (H3 a0). - match_option; [|tauto]. - assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad2 i)) (Some d4) - l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply (H3 a0). - match_option; [|tauto]. - unfold lift2; simpl. - - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 (s, DTfloat) H2). - assert (vartlookup d3 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H8 (s, DTfloat) H0). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 (s, DTfloat) H1). - - case_eq (vartlookup d2 (s, DTfloat)); [intros|tauto]. - case_eq (vartlookup d3 (s, DTfloat)); [intros|tauto]. - case_eq (vartlookup d4 (s, DTfloat)); [intros|tauto]. - - rewrite (split_subvar d2 d5 d d8); trivial. - rewrite (split_subvar d3 d6 d0 d9); trivial. - rewrite (split_subvar d4 d7 d1 d10); trivial. - - f_equal. - - specialize (IHl d3 d4 d2 H14 H15 H13). - rewrite H16, H17, H18 in IHl. - rewrite eqq2, eqq3, eqq4 in IHl. - unfold lift, lift2 in IHl. - inversion IHl. - rewrite H20. - - specialize (H a (grad1 a) (grad2 a) grad_env1 grad_env2 grad_env3 H0 H1 H2). - specialize (H (H3 a)). - - rewrite eqq,eqq0,eqq1 in H. - rewrite H7, H8, H9 in H. - unfold lift, lift2 in H. - inversion H. - rewrite H21. - lra. - - Case "DMatrix"%string. - unfold two_matrix_env_iter_alt in *. - rewrite vforall_forall in H3. - revert grad_env1 grad_env2 grad_env3 H0 H1 H2. - induction (bounded_seq0 n); induction (bounded_seq0 m). - + intros. - simpl. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1. - f_equal; lra. - + intros. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - simpl. - unfold lift, lift2; simpl; f_equal. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1; simpl; lra. - + intros. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - rewrite list_env_iter_id. - rewrite list_env_iter_id. - rewrite list_env_iter_id. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1; simpl; f_equal; lra. - + intros. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - simpl. - assert (df_eval_backprop_deriv σ (x a a0) grad_env3 (grad1 a a0 + grad2 a a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - assert (df_eval_backprop_deriv σ (x a a0) grad_env1 (grad1 a a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - assert (df_eval_backprop_deriv σ (x a a0) grad_env2 (grad2 a a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - case_eq (df_eval_backprop_deriv σ (x a a0) grad_env3 (grad1 a a0 + grad2 a a0)%R) - ; [intros|tauto]. - case_eq (df_eval_backprop_deriv σ (x a a0) grad_env1 (grad1 a a0)%R) - ; [intros|tauto]. - case_eq (df_eval_backprop_deriv σ (x a a0) grad_env2 (grad2 a a0)%R) - ; [intros|tauto]. - assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R) - (Some d2) l0 <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R) - (Some d2) l0 ); [intros | tauto]. - assert (list_env_iter - (fun (i : {n' : nat | n' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a i) env (grad1 a i)) (Some d3) - l0 <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad1 a j)%R) - (Some d3) l0 ); [intros | tauto]. - assert (list_env_iter - (fun (i : {n' : nat | n' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a i) env (grad2 a i)) (Some d4) - l0 <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad2 a j)%R) - (Some d4) l0 ); [intros | tauto]. - assert - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad1 i j + grad2 i j)%R) - (df_eval_backprop_deriv σ (x i a0) env (grad1 i a0 + grad2 i a0)%R) l0) - (Some d5) l <> None). - apply list_env_iter_total_fun; intros. - assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0 + grad2 a1 a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0 + grad2 a1 a0)%R); [intros|tauto]. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - unfold lift; simpl. - match_option; [|tauto]. - - assert - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad1 i j)%R) - (df_eval_backprop_deriv σ (x i a0) env (grad1 i a0)%R) l0) - (Some d6) l <> None). - apply list_env_iter_total_fun; intros. - assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0)%R); [intros|tauto]. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - match_option; [|tauto]. - - assert - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad2 i j)%R) - (df_eval_backprop_deriv σ (x i a0) env (grad2 i a0)%R) l0) - (Some d7) l <> None). - apply list_env_iter_total_fun; intros. - assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad2 a1 a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad2 a1 a0)%R); [intros|tauto]. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - match_option; [|tauto]. - - unfold lift2; simpl; f_equal. - - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 (s, DTfloat) H2). - assert (vartlookup d3 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H8 (s, DTfloat) H0). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 (s, DTfloat) H1). - - case_eq (vartlookup d2 (s, DTfloat)); [intros|tauto]. - case_eq (vartlookup d3 (s, DTfloat)); [intros|tauto]. - case_eq (vartlookup d4 (s, DTfloat)); [intros|tauto]. - - assert (Hc := H). - assert (H3c := H3). - - specialize (H a a0 (grad1 a a0) (grad2 a a0) grad_env1 grad_env2 grad_env3 - H0 H1 H2). - - specialize (H3 a). - rewrite vforall_forall in H3. - specialize (H (H3 a0)). - rewrite eqq, eqq0, eqq1 in H; simpl in H. - rewrite H7, H8, H9 in H; unfold lift, lift2 in H; simpl in H. - - assert (vartlookup d5 (s, DTfloat) <> None). - apply (vartlookup_list_env_iter - s (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R) - l0 d2 d5); trivial; intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H25 (s, DTfloat) H26). - case_eq (vartlookup d5 (s, DTfloat)); [intros|tauto]. - assert (vartlookup d6 (s, DTfloat) <> None). - apply (vartlookup_list_env_iter - s (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad1 a j)%R) - l0 d3 d6); trivial; intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H27 (s, DTfloat) H28). - case_eq (vartlookup d6 (s, DTfloat)); [intros|tauto]. - assert (vartlookup d7 (s, DTfloat) <> None). - apply (vartlookup_list_env_iter - s (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad2 a j)%R) - l0 d4 d7); trivial; intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H29 (s, DTfloat) H30). - case_eq (vartlookup d7 (s, DTfloat)); [intros|tauto]. - - specialize (IHl d6 d7 d5 H27 H29 H25). - rewrite H28, H30, H26 in IHl; simpl in IHl. - - rewrite eqq2,eqq3,eqq4 in IHl. - unfold lift, lift2 in IHl; simpl in IHl. - - rewrite (split_subvar d5 d8 d d14); trivial. - rewrite (split_subvar d6 d9 d0 d15); trivial. - rewrite (split_subvar d7 d10 d1 d16); trivial. - - rewrite (split_subvar d2 d5 d d11); trivial. - rewrite (split_subvar d3 d6 d0 d12); trivial. - rewrite (split_subvar d4 d7 d1 d13); trivial. - - inversion H. - inversion IHl. - rewrite H32, H33. - - generalize (backprop_grad_sum_list_env_iter - σ (x a) s d3 d4 d2 (grad1 a) (grad2 a) - d12 d13 d11 l0); intros. - specialize (H31 H3). - cut_to H31. - * rewrite H11,H13,H15 in H31. - unfold lift, lift2 in H31. - inversion H31. - rewrite H35; lra. - * intros. - unfold df_eval_backprop_delta. - specialize (Hc a i (grad1 a i) (grad2 a i) env1 env2 env3). - specialize (Hc H34 H35 H36). - specialize (Hc (H3 i)). - apply Hc. - * trivial. - * trivial. - * trivial. - - Case "Var"%string. - match_option; [|tauto]. - case_eq (vartlookup grad_env1 (s, DTfloat)); [intros| tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros| tauto]. - destruct (v == (s,DTfloat)). - + invcs e. - rewrite eqq, H3, H4; simpl. - f_equal. - rewrite subvar_addvar_scalar_eq; trivial. - rewrite subvar_addvar_scalar_eq; trivial. - rewrite subvar_addvar_scalar_eq; trivial. - + assert (v<> (s, DTfloat)) by congruence. - match_option. - * match_option. - -- match_option; unfold lift, lift2; simpl; f_equal. - ++ rewrite subvar_addvar_scalar_neq; trivial. - rewrite subvar_addvar_scalar_neq; trivial. - rewrite subvar_addvar_scalar_neq; trivial. - lra. - ++ rewrite subvar_addvar_scalar_neq; trivial. - rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl; rewrite H4; lra. - -- unfold lift, lift2; simpl; f_equal. - rewrite subvar_addvar_scalar_neq; trivial. - case_eq (vartlookup grad_env2 v); intros. - ++ rewrite subvar_addvar_scalar_neq; trivial. - f_equal; unfold subvar; simpl. - rewrite H3; lra. - ++ unfold subvar; simpl. - rewrite H3, H4. - f_equal; lra. - * match_option. - -- match_option; unfold lift, lift2; simpl; f_equal. - ++ rewrite subvar_addvar_scalar_neq; trivial. - rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl. - rewrite eqq; lra. - ++ rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl. - rewrite eqq, H4; lra. - -- match_option; unfold lift, lift2; simpl; f_equal. - ++ rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl. - rewrite eqq, H3; lra. - ++ unfold subvar; simpl. - rewrite eqq, H3, H4; lra. - - Case "Plus"%string. - destruct H2. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 grad1 grad2 env1 env3 env). - simpl in IHdf1. - rewrite eqq5, eqq6, eqq7 in IHdf1. - rewrite eqq, eqq1, eqq3 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - rewrite eqq0, eqq2, eqq4 in IHdf2. - unfold lift, lift2 in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 val val2); trivial. - rewrite (split_subvar env1 env2 val0 val3); trivial. - rewrite (split_subvar env3 env4 val1 val4); trivial. - rewrite H5, H6; lra. - - Case "Minus"%string. - destruct H2. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 (-grad1)%R (-grad2)%R env1 env3 env). - simpl in IHdf1. - rewrite eqq5, eqq6, eqq7 in IHdf1. - rewrite eqq, eqq1, eqq3 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - replace (- grad1 + - grad2)%R with (- (grad1 + grad2))%R in IHdf2 by lra. - rewrite eqq0, eqq2, eqq4 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - rewrite (split_subvar env env0 val val2); trivial. - rewrite (split_subvar env1 env2 val0 val3); trivial. - rewrite (split_subvar env3 env4 val1 val4); trivial. - f_equal; rewrite H5, H6; lra. - - Case "Times"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 (d1 * grad1)%R (d1 * grad2)%R grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 (d0 * grad1)%R (d0 * grad2)%R env1 env3 env). - rewrite eqq8, eqq9, eqq in IHdf1. - simpl in IHdf1. - replace (d1 *grad1 + d1*grad2)%R with (d1 * (grad1 + grad2))%R in IHdf1 by lra. - rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - replace (d0 *grad1 + d0 *grad2)%R with (d0 *(grad1 + grad2))%R in IHdf2 by lra. - rewrite eqq3,eqq5,eqq7 in IHdf2. - unfold lift, lift2 in IHdf2; inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 d val1); trivial. - rewrite (split_subvar env1 env2 val val2); trivial. - rewrite (split_subvar env3 env4 val0 val3); trivial. - rewrite H7, H8; lra. - - Case "Divide"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 (grad1/d1)%R (grad2/d1)%R grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 (-d0/(d1*d1) * grad1)%R (-d0/(d1*d1) * grad2)%R env1 env3 env). - rewrite eqq, eqq8, eqq9 in IHdf1. - simpl in IHdf1. - replace (grad1/d1 + grad2/d1)%R with ((grad1 + grad2)/d1)%R in IHdf1 by lra. - rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - replace (-d0/(d1*d1) *grad1 + -d0/(d1*d1) *grad2)%R with (-d0/(d1*d1) *(grad1 + grad2))%R in IHdf2 by lra. - rewrite eqq3, eqq5, eqq7 in IHdf2. - unfold lift, lift2 in IHdf2; inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 d val1); trivial. - rewrite (split_subvar env1 env2 val val2); trivial. - rewrite (split_subvar env3 env4 val0 val3); trivial. - rewrite H7, H8; lra. - - Case "Square"%string. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df); intros. - specialize (H3 H2). - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (2 * d0 *grad1)%R (2 * d0 *grad2)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq1, eqq2 in IHdf. - simpl in IHdf. - replace (2 * d0 * grad1 + 2 * d0 * grad2)%R with (2 * d0 * (grad1 + grad2))%R in IHdf by lra. - apply IHdf. - - Case "Exp"%string. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df); intros. - specialize (H3 H2). - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (grad1 * exp d0)%R (grad2*exp d0)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq1, eqq2 in IHdf. - simpl in IHdf. - replace (grad1 * exp d0 + grad2 * exp d0)%R with ((grad1 + grad2) * exp d0)%R in IHdf by lra. - apply IHdf. - - Case "Log"%string. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df); intros. - specialize (H3 H2). - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (grad1/d0)%R (grad2/d0)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq1, eqq2 in IHdf. - simpl in IHdf. - replace (grad1/d0 + grad2/d0)%R with ((grad1 + grad2) / d0)%R in IHdf by lra. - apply IHdf. - - Case "Abs"%string. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df); intros. - specialize (H3 H2). - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (grad1*sign d0)%R (grad2*sign d0)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq1, eqq2 in IHdf. - simpl in IHdf. - replace (grad1*sign d0 + grad2*sign d0)%R with ((grad1 + grad2) * sign d0)%R in IHdf by lra. - apply IHdf. - - Case "Sign"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (0)%R (0)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - replace (0 + 0)%R with 0%R in IHdf by lra. - apply IHdf. - - Case "PSign"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (0)%R (0)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - replace (0 + 0)%R with 0%R in IHdf by lra. - apply IHdf. - - Case "Max"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - destruct (Rle_dec d0 d1). - + match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf2 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H3). - rewrite eqq, eqq2, eqq3 in IHdf2; simpl in IHdf2. - apply IHdf2. - + match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - rewrite eqq, eqq2, eqq3 in IHdf1; simpl in IHdf1. - apply IHdf1. - - Case "VectorDot"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 - (vmap (fun rv : R => (rv * grad1)%R) d1) - (vmap (fun rv : R => (rv * grad2)%R) d1) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 - (vmap (fun lv : R => (lv * grad1)%R) d0) - (vmap (fun lv : R => (lv * grad2)%R) d0) - env1 env3 env). - rewrite eqq, eqq8, eqq9 in IHdf1. - simpl in IHdf1. - replace - (fun i : {n' : nat | n' < n} => - (vmap (fun rv : R => rv * grad1) d1 i + vmap (fun rv : R => rv * grad2) d1 i)%R) - with - (vmap (fun rv : R => (rv * (grad1 + grad2))%R) d1) - in IHdf1. - + rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; inversion IHdf1. - simpl in IHdf2. - cut_to IHdf2; simpler2. - replace (fun i : {n' : nat | n' < n} => - (vmap (fun lv : R => lv * grad1) d0 i + - vmap (fun lv : R => lv * grad2) d0 i)%R) with - (vmap (fun lv : R => (lv * (grad1 + grad2))%R) d0) in IHdf2. - * rewrite eqq3, eqq5, eqq7 in IHdf2. - unfold lift, lift2 in IHdf2; inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 d val1); trivial. - rewrite (split_subvar env1 env2 val val2); trivial. - rewrite (split_subvar env3 env4 val0 val3); trivial. - rewrite H7, H8; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite vmap_nth. - lra. - + apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite vmap_nth. - lra. - - Case "VectorSum"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (ConstVector n grad1) (ConstVector n grad2) grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - unfold ConstVector in IHdf. - unfold ConstVector. - apply IHdf. - - Case "MatrixSum"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (ConstMatrix m n grad1) (ConstMatrix m n grad2) grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - unfold ConstMatrix in IHdf. - unfold ConstMatrix. - apply IHdf. - - Case "VectorElem"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (fun k : {n' : nat | n' < n} => - if equiv_dec (` k) (` i) then grad1 else 0%R) - (fun k : {n' : nat | n' < n} => - if equiv_dec (` k) (` i) then grad2 else 0%R) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - replace (fun k : {n' : nat | n' < n} => - if equiv_dec (` k) (` i) then (grad1 + grad2)%R else 0%R) with - (fun i0 : {n' : nat | n' < n} => - ((if equiv_dec (` i0) (` i) then grad1 else 0) + - (if equiv_dec (` i0) (` i) then grad2 else 0))%R). - apply IHdf. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)); lra. - - Case "MatrixElem"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf - (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) - then if equiv_dec (` k2) (` j) then grad1 else 0%R - else 0%R) - (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) - then if equiv_dec (` k2) (` j) then grad2 else 0%R - else 0%R) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - replace (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) - then if equiv_dec (` k2) (` j) then (grad1 + grad2)%R else 0%R - else 0%R) with - (fun (i0 : {n' : nat | n' < m}) (j0 : {m' : nat | m' < n}) => - ((if equiv_dec (` i0) (` i) - then if equiv_dec (` j0) (` j) then grad1 else 0 - else 0) + - (if equiv_dec (` i0) (` i) - then if equiv_dec (` j0) (` j) then grad2 else 0 - else 0))%R). - apply IHdf. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)). - + apply FunctionalExtensionality.functional_extensionality; intros. - destruct (equiv_dec (` x0) (` j)); lra. - + apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "MatrixVectorMult"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 - (fun (i : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (grad1 i * d1 j)%R) - (fun (i : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (grad2 i * d1 j)%R) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - rewrite eqq, eqq8, eqq9 in IHdf1. - specialize (IHdf2 - (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d0 j i) grad1) - (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d0 j i) grad2) - env1 env3 env). - simpl in IHdf1. - replace - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (grad1 i * d1 j + grad2 i * d1 j)%R) with - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - ((grad1 i + grad2 i) * d1 j)%R) in IHdf1. - + rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - replace - (fun i : {n' : nat | n' < n} => - (matrix_vector_mult - (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d0 j i0) grad1 i + - matrix_vector_mult - (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d0 j i0) grad2 i)%R) with - (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => d0 j i) - (fun i : {n' : nat | n' < m} => (grad1 i + grad2 i)%R)) in IHdf2. - * rewrite eqq3, eqq5, eqq7 in IHdf2. - unfold lift, lift2 in IHdf2; inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 d val1); trivial. - rewrite (split_subvar env1 env2 val val2); trivial. - rewrite (split_subvar env3 env4 val0 val3); trivial. - rewrite H7, H8; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - unfold matrix_vector_mult. - rewrite vsum_plus; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - + apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "MatrixVectorAdd"%string. - destruct H2. - simpl; intros. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - simpl in IHdf1. - rewrite eqq,eqq0,eqq1,eqq2,eqq3,eqq4 in IHdf1. - unfold lift, lift2 in IHdf1. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H); intro. - case_eq (vartlookup env0 (s, DTfloat)); [intros|tauto]. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat) H0); intro. - case_eq (vartlookup env1 (s, DTfloat)); [intros|tauto]. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq (s, DTfloat) H1); intro. - case_eq (vartlookup env (s, DTfloat)); [intros|tauto]. - generalize (backprop_mat_grad_sum_list_env_iter - σ df2 s env0 env1 env (transpose grad1) (transpose grad2) - d d0 d1 (bounded_seq0 n)); intros. - simpl in H10. - specialize (H10 H3). - cut_to H10; trivial. - + do 3 match_option. - * rewrite eqq6, eqq7 in H10. - replace (fun (j : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env - (fun k : {n' : nat | n' < m} => ((@transpose R m n grad1 j k) + - (@transpose R m n grad2 j k))%R)) - with - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env - (transpose - (fun (i0 : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (grad1 i0 j + grad2 i0 j)%R) i)) in H10. - -- rewrite eqq5 in H10. - unfold lift, lift2 in H10. - unfold lift, lift2; f_equal. - inversion IHdf1; inversion H10. - rewrite (split_subvar env d2 val d1); trivial. - rewrite (split_subvar env0 d3 val0 d); trivial. - rewrite (split_subvar env1 d4 val1 d0); trivial. - rewrite H12, H13; lra. - -- apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - f_equal. - * assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad2 i)) - (Some env1) (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - * assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad1 i)) - (Some env0) (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - * assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad1 i)) - (Some env0) (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - * assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env - (transpose - (fun (i0 : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (grad1 i0 j + grad2 i0 j)%R) i)) (Some env) (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - + intros. - apply IHdf2; trivial. - - Case "MatrixMult"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize - (IHdf1 - (matrix_mult grad1 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i)) - (matrix_mult grad2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i)) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - rewrite eqq, eqq8, eqq9 in IHdf1. - specialize (IHdf2 - (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d0 j i) - grad1) - (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d0 j i) - grad2) - env1 env3 env). - simpl in IHdf1. - replace - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < p}) => - (matrix_mult grad1 - (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) => - d1 j0 i0) i j + - matrix_mult grad2 - (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) => - d1 j0 i0) i j)%R) - with - (matrix_mult - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (grad1 i j + grad2 i j)%R) - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i)) in IHdf1. - + rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - rewrite eqq5, eqq7 in IHdf2. - replace - (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < n}) => - (matrix_mult - (fun (i0 : {n' : nat | (n' < p)%nat}) (j0 : {m' : nat | (m' < m)%nat}) => - d0 j0 i0) grad1 i j + - matrix_mult - (fun (i0 : {n' : nat | (n' < p)%nat}) (j0 : {m' : nat | (m' < m)%nat}) => - d0 j0 i0) grad2 i j)%R) with - (matrix_mult (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < m}) => d0 j i) - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (grad1 i j + grad2 i j)%R)) - in IHdf2. - * rewrite eqq3 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 d val1); trivial. - rewrite (split_subvar env1 env2 val val2); trivial. - rewrite (split_subvar env3 env4 val0 val3); trivial. - rewrite H7, H8; lra. - * unfold matrix_mult. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vsum_plus; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - + unfold matrix_mult. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vsum_plus; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - - Case "VectorPlus"%string. - destruct H2. - unfold lift, lift2. - repeat simpl_closed_backprop. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 grad1 grad2 env1 env3 env). - simpler2. - simpl in IHdf1. - rewrite eqq1, eqq3, eqq in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - cut_to IHdf2; try congruence. - simpl in IHdf2. - rewrite eqq0, eqq2, eqq4 in IHdf2. - unfold lift, lift2 in IHdf2; inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 val val5); trivial. - rewrite (split_subvar env1 env2 val0 val6); trivial. - rewrite (split_subvar env3 env4 val1 val7); trivial. - rewrite eqq5 in eqq8. - rewrite eqq6 in eqq9. - rewrite eqq7 in eqq10. - invcs eqq8; invcs eqq9; invcs eqq10. - rewrite H5, H6; lra. - - Case "VectorMinus"%string. - destruct H2. - unfold lift, lift2. - repeat simpl_closed_backprop. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad1 i)%R) - (fun i : {n' : nat | n' < n} => (- grad2 i)%R) - env1 env3 env). - simpler2. - simpl in IHdf1. - rewrite eqq, eqq1, eqq3 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - cut_to IHdf2; try congruence. - simpl in IHdf2. - replace (fun i : {n' : nat | n' < n} => (- grad1 i + - grad2 i)%R) with - (fun i : {n' : nat | n' < n} => (- (grad1 i + grad2 i))%R) in IHdf2. - rewrite eqq0, eqq2, eqq4 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 val val5); trivial. - rewrite (split_subvar env1 env2 val0 val6); trivial. - rewrite (split_subvar env3 env4 val1 val7); trivial. - rewrite eqq5 in eqq8; rewrite eqq6 in eqq9; rewrite eqq7 in eqq10. - invcs eqq8; invcs eqq9; invcs eqq10. - rewrite H5, H6; lra. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "MatrixPlus"%string. - destruct H2. - match_option; [|tauto]. - assert (df_eval_backprop_deriv - σ df1 grad_env3 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 (grad1)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 (grad2)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 grad1 grad2 d2 d4 d0). - rewrite eqq, eqq1, eqq3 in IHdf1. - simpl in IHdf1. - rewrite eqq0, eqq2, eqq4 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H0). - assert (vartlookup d0 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H1). - specialize (IHdf2 H7 H9 H10 H3). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - simpl in IHdf2. - unfold lift, lift2; simpl. - assert (df_eval_backprop_deriv - σ df2 d0 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R) <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d2 (grad1)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d2 (grad1)%R); [intros|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 (grad2)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d4 (grad2)%R); [intros|tauto]. - simpl in IHdf2. - rewrite eqq8, H13, H15 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar d0 d8 d d5); trivial. - rewrite (split_subvar d2 d9 d1 d6); trivial. - rewrite (split_subvar d4 d10 d3 d7); trivial. - rewrite H8, H17; lra. - - Case "MatrixMinus"%string. - destruct H2. - match_option; [|tauto]. - assert (df_eval_backprop_deriv - σ df1 grad_env3 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R) <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 (grad1)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 (grad2)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (- grad1 i j)%R) - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (- grad2 i j)%R) - d2 d4 d0). - rewrite eqq, eqq1, eqq3 in IHdf1. - simpl in IHdf1. - rewrite eqq0, eqq2, eqq4 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H0). - assert (vartlookup d0 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H1). - specialize (IHdf2 H7 H9 H10 H3). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift, lift2; simpl. - assert (df_eval_backprop_deriv - σ df2 d0 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (- (grad1 i j + grad2 i j))%R) <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - match_option; [|tauto]. - assert (df_eval_backprop_deriv - σ df2 d2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad1 i j)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad1 i j)%R)) - ; [intros |tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad2 i j)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d4 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad2 i j)%R)) - ; [intros | tauto]. - simpl in IHdf2. - replace - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (- grad1 i j + - grad2 i j)%R) with - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (- (grad1 i j + grad2 i j))%R) in IHdf2. - rewrite eqq8, H13, H15 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar d0 d8 d d5); trivial. - rewrite (split_subvar d2 d9 d1 d6); trivial. - rewrite (split_subvar d4 d10 d3 d7); trivial. - rewrite H8, H17; lra. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "VectorScalMult"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env3 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * (grad1 j + grad2 j))%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad1 j)%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad2 j)%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - specialize (IHdf1 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad1 j)%R)) - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad2 j)%R)) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 - (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R) - (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R) - d4 d6 d2). - rewrite eqq, eqq3, eqq5 in IHdf1. - simpl in IHdf1. - replace - (@vsum floatish_R n (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad1 j) + - @vsum floatish_R n (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad2 j))%R with - (vsum (fun j : {n' : nat | n' < n} => (d1 j * (grad1 j + grad2 j))%R)) in IHdf1. - + rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H1). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H). - assert (vartlookup d6 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq6 (s, DTfloat) H0). - specialize (IHdf2 H11 H12 H9 H3). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift, lift2; simpl. - assert (df_eval_backprop_deriv - σ df2 d2 - (fun j : {n' : nat | n' < n} => (d0 * (grad1 j + grad2 j))%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d4 - (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R)) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ df2 d6 - (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R) - <> None ). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d6 - (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R)) - ; [intros | tauto]. - simpl in IHdf2. - replace - (fun i : {n' : nat | n' < n} => (d0 * grad1 i + d0 * grad2 i)%R) with - (fun j : {n' : nat | n' < n} => (d0 * (grad1 j + grad2 j))%R) in IHdf2. - * rewrite eqq10, H15, H17 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar d2 d10 d d7); trivial. - rewrite (split_subvar d4 d11 d3 d8); trivial. - rewrite (split_subvar d6 d12 d5 d9); trivial. - rewrite H10, H19; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + simpl. - rewrite vsum_plus. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "MatrixScalMult"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - assert (df_eval_backprop_deriv - σ df1 grad_env3 - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * (grad1 i j + grad2 i j))%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * grad1 i j)%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * grad2 i j)%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - specialize (IHdf1 - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * grad1 i j)%R)) - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * grad2 i j)%R)) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (grad1 i j * d0)%R) - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (grad2 i j * d0)%R) - d4 d6 d2). - rewrite eqq, eqq3, eqq5 in IHdf1. - simpl in IHdf1. - replace - (@msum floatish_R n m - (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d1 i j * grad1 i j) + - @msum floatish_R n m - (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d1 i j * grad2 i j))%R with - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * (grad1 i j + grad2 i j))%R)) in IHdf1. - + rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H1). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H). - assert (vartlookup d6 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq6 (s, DTfloat) H0). - specialize (IHdf2 H11 H12 H9 H3). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift, lift2; simpl. - assert (df_eval_backprop_deriv - σ df2 d2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - ((grad1 i j + grad2 i j) * d0)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad1 i j * d0)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d4 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad1 i j * d0)%R)) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ df2 d6 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad2 i j * d0)%R) - <> None ). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d6 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad2 i j * d0)%R)) - ; [intros | tauto]. - simpl in IHdf2. - replace - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (grad1 i j * d0 + grad2 i j * d0)%R) with - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - ((grad1 i j + grad2 i j) * d0)%R) in IHdf2. - * rewrite eqq10, H15, H17 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar d2 d10 d d7); trivial. - rewrite (split_subvar d4 d11 d3 d8); trivial. - rewrite (split_subvar d6 d12 d5 d9); trivial. - rewrite H10, H19; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + unfold msum. - rewrite vsum_plus. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite vsum_plus. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "VectorApply"%string. - destruct H2. - simpler2. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H4 H3). - match_option. - match_option. - + match_option. - * match_option. - -- unfold lift. - repeat simpl_closed_backprop. - unfold lift2. - specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros. - specialize (IHdf2 v1 v2 grad_env1 grad_env2 grad_env3). - cut_to IHdf2; try congruence. - rewrite eqq, eqq0, eqq1, eqq7, eqq8 in IHdf2; simpl in IHdf2. - replace (fun i : {n' : nat | n' < n} => (v1 i + v2 i)%R) with v0 in IHdf2. - rewrite eqq6 in IHdf2. - unfold lift, lift2 in IHdf2. - apply IHdf2. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H5 x);rewrite vmap_nth in H5; simpl in H5. - specialize (H6 x);rewrite vmap_nth in H6; simpl in H6. - specialize (H7 x);rewrite vmap_nth in H7; simpl in H7. - match_option_in H5; invcs H5. - match_option_in H6; invcs H6. - match_option_in H7; invcs H7. - rewrite eqq9 in eqq10; invcs eqq10. - rewrite eqq9 in eqq11; invcs eqq11. - lra. - -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - * specialize (apply vectoro_to_ovector_exists_None eqq4); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - + specialize (apply vectoro_to_ovector_exists_None eqq3); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - - Case "MatrixApply"%string. - destruct H2. - simpler2. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H4 H3). - match_option. - match_option. - + match_option. - * match_option. - -- unfold lift. - repeat simpl_closed_backprop. - unfold lift2. - unfold matrixo_to_omatrix in *. - specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros. - specialize (IHdf2 m1 m2 grad_env1 grad_env2 grad_env3). - cut_to IHdf2; try congruence. - rewrite eqq, eqq0, eqq1, eqq7, eqq8 in IHdf2; simpl in IHdf2. - replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (m1 i j + m2 i j)%R) with m0 in IHdf2. - rewrite eqq6 in IHdf2. - unfold lift, lift2 in IHdf2. - apply IHdf2. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H5 x); specialize (H6 x); specialize (H7 x). - unfold mmap in H5; unfold mmap in H6; unfold mmap in H7. - specialize (apply vectoro_to_ovector_forall_some_f H5);intros. - specialize (apply vectoro_to_ovector_forall_some_f H6);intros. - specialize (apply vectoro_to_ovector_forall_some_f H7);intros. - specialize (H8 x0); do 2 rewrite vmap_nth in H8. - specialize (H9 x0); do 2 rewrite vmap_nth in H9. - specialize (H10 x0); do 2 rewrite vmap_nth in H10. - rewrite matrix_zip_m_n in H8. - rewrite matrix_zip_m_n in H9. - rewrite matrix_zip_m_n in H10. - match_option_in H8. - rewrite eqq9 in H9; rewrite eqq9 in H10. - invcs H8; invcs H9; invcs H10. - lra. - -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros; destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5. - unfold mmap in e0. - do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - * specialize (apply vectoro_to_ovector_exists_None eqq4); intros; destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5. - unfold mmap in e0. - do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - + specialize (apply vectoro_to_ovector_exists_None eqq3); intros; destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5. - unfold mmap in e0. - do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - - Case "VLossfun"%string. - destruct H2. - simpler2. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H4 H3). - match_option. - match_option. - + match_option. - * match_option. - -- unfold lift. - repeat simpl_closed_backprop. - unfold lift2. - specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros. - specialize (IHdf2 v0 v3 grad_env1 grad_env2 grad_env3). - cut_to IHdf2; try congruence. - rewrite eqq, eqq0, eqq1 in IHdf2. - rewrite eqq7, eqq8 in IHdf2. - simpl in IHdf2. - replace (fun i : {n' : nat | n' < n} => (v0 i + v3 i)%R) with v in IHdf2. - rewrite eqq6 in IHdf2. - unfold lift, lift2 in IHdf2. - apply IHdf2. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H5 x);rewrite vmap_nth in H5; simpl in H5. - specialize (H6 x);rewrite vmap_nth in H6; simpl in H6. - specialize (H7 x);rewrite vmap_nth in H7; simpl in H7. - match_option_in H5. - rewrite eqq9 in H6; rewrite eqq9 in H7. - invcs H5; invcs H6; invcs H7. - lra. - -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert ( df_eval [mk_env_entry (v1, DTfloat) (d x); - mk_env_entry (v2, DTfloat) (r x)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - * specialize (apply vectoro_to_ovector_exists_None eqq4); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x); - mk_env_entry (v2, DTfloat) (r x)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - + specialize (apply vectoro_to_ovector_exists_None eqq3); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x); mk_env_entry (v2, DTfloat) (r x)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - - Case "MLossfun"%string. - destruct H2. - simpler2. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H4 H3). - match_option. - match_option. - + match_option. - * match_option. - -- unfold lift. - repeat simpl_closed_backprop. - unfold lift2. - specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros. - specialize (IHdf2 m1 m2 grad_env1 grad_env2 grad_env3). - cut_to IHdf2; try congruence. - rewrite eqq, eqq0, eqq1 in IHdf2. - rewrite eqq7, eqq8 in IHdf2. - simpl in IHdf2. - replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (m1 i j + m2 i j)%R) with m0 in IHdf2. - rewrite eqq6 in IHdf2. - unfold lift, lift2 in IHdf2. - apply IHdf2. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H5 x); simpl in H5. - specialize (H6 x); simpl in H6. - specialize (H7 x); simpl in H7. - specialize (apply vectoro_to_ovector_forall_some_f H5);intros. - specialize (apply vectoro_to_ovector_forall_some_f H6);intros. - specialize (apply vectoro_to_ovector_forall_some_f H7);intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H8 x0); unfold mmap in H8;do 2 rewrite vmap_nth in H8; simpl in H8. - specialize (H9 x0); unfold mmap in H9;do 2 rewrite vmap_nth in H9; simpl in H9. - specialize (H10 x0); unfold mmap in H10;do 2 rewrite vmap_nth in H10; simpl in H10. - rewrite matrix_zip_m_n in H8. - rewrite matrix_zip_m_n in H9. - rewrite matrix_zip_m_n in H10. - match_option_in H8. - rewrite eqq9 in H9; rewrite eqq9 in H10. - invcs H8; invcs H9; invcs H10. - lra. - -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros. - destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H5. - unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert ( df_eval [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - * specialize (apply vectoro_to_ovector_exists_None eqq4); intros. - destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H5. - unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - + specialize (apply vectoro_to_ovector_exists_None eqq3); intros. - destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H5. - unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - Qed. - - Lemma vectoro_to_ovector_eNone_None {A n} {vo:Vector (option A) n} : - {i | vo i = None} -> - vectoro_to_ovector vo = None. - Proof. - intros. - destruct H. - now apply (vectoro_to_ovector_None_None x). - Qed. - - Definition ConstSplitVector {T} (middle theend:nat) (part1 part2:T) : (Vector T theend) := - fun (i: {n':nat | n' < theend}%nat) => - if lt_dec (proj1_sig i) middle then part1 else part2. - - Definition mergeVectorZero {T} {n} (middle:nat) (part1 : Vector T n) (c:T) : (Vector T n) := - fun (i: {n':nat | n' < n}%nat) => - if lt_dec (proj1_sig i) middle then (part1 i) else c. - - Definition scaleUnitVector {T} (n:nat) (j : {n':nat | (n' < n)%nat}) (c:T) (zero:T) : Vector T n := - fun i => if (proj1_sig i) == (proj1_sig j) then c else zero%R. - - Lemma ConstSplitVectorSzero bound n (pf:bound < n) : - (ConstSplitVector (S bound) n 1%R 0%R) = - dfti_gen_plus (T:=DTVector n) (ConstSplitVector bound n 1%R 0%R) (UnitVector n (exist _ bound pf)). - Proof. - simpl. - apply functional_extensionality. - intros. - unfold ConstSplitVector, UnitVector, equiv_dec, nat_eq_eqdec; simpl. - destruct x as [x pff]; simpl. - destruct (lt_dec x (S bound)) - ; destruct (lt_dec x bound) - ; destruct (Nat.eq_dec x bound) - ; try lia; try lra. - Qed. - - Lemma mergeVectorSzero {n} (bound:nat) (pf:bound < n) (part1 : Vector R n): - let ind := (exist _ bound pf) in - (@mergeVectorZero (@float floatish_R) n (S bound) part1 0%R) = - dfti_gen_plus (T:=DTVector n) (@mergeVectorZero (@float floatish_R) n bound part1 0%R) - (scaleUnitVector n ind (part1 ind) 0%R). - Proof. - simpl. - apply functional_extensionality. - intros. - unfold mergeVectorZero, scaleUnitVector, equiv_dec, nat_eq_eqdec; simpl. - destruct x as [x pff]; simpl. - destruct (lt_dec x (S bound)) - ; destruct (lt_dec x bound) - ; destruct (Nat.eq_dec x bound) - ; try lia; try lra. - subst; simpl. - ring_simplify. - erewrite index_pf_irrel; eauto. - Qed. - - Lemma mergeVectorSzero_mat {n m} (bound:nat) pf (part1 : Matrix float n m) : - let ind := (exist _ bound pf) in - (@mergeVectorZero (Vector float m) n (S bound) part1 (ConstVector m 0%R)) = - dfti_gen_plus (T:=DTMatrix n m) (@mergeVectorZero (Vector float m) n bound part1 (ConstVector m 0%R)) - (scaleUnitVector (T:=Vector float m) n ind (part1 ind) (ConstVector m 0%R)). - Proof. - simpl. - do 2 (apply functional_extensionality; intros). - unfold mergeVectorZero, scaleUnitVector, ConstVector, equiv_dec, nat_eq_eqdec; simpl. - destruct x as [x pff]; simpl. - destruct (lt_dec x (S bound)) - ; destruct (lt_dec x bound) - ; destruct (Nat.eq_dec x bound) - ; simpl - ; try lia; try lra. - subst; simpl. - rewrite Rplus_0_l. - erewrite index_pf_irrel; eauto. - Qed. - - Lemma vsum_alt_eq {m:nat} (v:Vector R m) : vsum v = vector_fold_right Fplus 0%R v. - Proof. - apply vector_fold_right1_as_vector_fold_right. - unfold Datatypes.id; simpl; intros; lra. - Qed. - - Lemma vsum_cons {m:nat} x (v:Vector R m) : - vsum (vcons x v) = (x + vsum v)%R. - Proof. - repeat rewrite vsum_alt_eq. - apply vector_fold_right_vcons. - Qed. - - Lemma constSplitVectorZero {n} : - ConstSplitVector 0 n 1%R 0%R = ConstVector n 0%R. - Proof. - unfold ConstSplitVector, ConstVector; simpl. - now apply functional_extensionality. - Qed. - - Lemma df_eval_backprop_delta_by_unit_partvec_bounded {n} bound (pf:(bound<=n)%nat) - (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (grad d:Vector float n): - - fully_closed_over df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall i : {n' : nat | n' < n}, - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (scaleUnitVector n i (grad i) 0%R)) = Some (d i)) -> - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (mergeVectorZero bound grad 0%R)) = Some (vsum (vfirstn d bound pf)). - Proof. - intros closed lo fa. - induction bound. - - replace (mergeVectorZero 0 grad 0%R) with (scalarMult (DTVector n) 0%R (mergeVectorZero 0 grad 0%R)). - + erewrite scalarMult_backprop_grad_scalar; try eassumption. - * simpl. - unfold df_eval_backprop_delta. - simpler2. - unfold lift. - simpl_closed_backprop. - f_equal. - vm_compute; lra. - * apply backprop_deriv_fully_closed_not_none; trivial. - * apply backprop_deriv_fully_closed_not_none; trivial. - + unfold scalarMult, mergeVectorZero. - apply functional_extensionality; intros; simpl. - lra. - - assert (pf2:bound < n) by lia. - assert (pf3:bound <= n) by lia. - rewrite (mergeVectorSzero _ pf2 _ ). - erewrite backprop_grad_sum; try eassumption. - specialize (IHbound pf3). - rewrite IHbound. - rewrite fa. - simpl. - f_equal. - destruct n; [lia | ]. - generalize (vector_Sn_split (vfirstn d (S bound) pf)); intros eqq1. - apply vec_eq_eq in eqq1. - simpl in *. - rewrite eqq1. - erewrite vfirstn_vdrop_last. - erewrite vlast_vfirstn. - rewrite vsum_cons. - rewrite Rplus_comm. - f_equal. - Qed. - - Lemma df_eval_backprop_delta_by_unit_partvec {n} - (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (grad d:Vector float n): - - fully_closed_over df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall i : {n' : nat | n' < n}, - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (scaleUnitVector n i (grad i) 0%R)) = Some (d i)) -> - df_eval_backprop_delta σ df (s, DTfloat) grad_env grad = - Some (vsum d). - Proof. - intros. - replace (grad) with (mergeVectorZero n grad 0%R). - - erewrite df_eval_backprop_delta_by_unit_partvec_bounded; try eassumption. - now rewrite vfirstn_eq. - - apply functional_extensionality; unfold mergeVectorZero; intros [x pff]; simpl. - destruct (lt_dec x n); trivial; lia. - Unshelve. - lia. - Qed. - - Lemma df_eval_backprop_delta_by_unit_parts {n} - (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (d:Vector float n): - - fully_closed_over df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall i : {n' : nat | n' < n}, - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - - (UnitVector n i)) = Some (d i)) -> - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (ConstVector n 1%R)) = Some (vsum d). - Proof. - intros. - replace (ConstVector n 1%R) with (mergeVectorZero n (ConstVector n 1%R) 0%R). - - erewrite df_eval_backprop_delta_by_unit_partvec_bounded; try eassumption. - now rewrite vfirstn_eq. - - apply functional_extensionality; unfold mergeVectorZero, ConstVector; intros [x pff]; simpl. - destruct (lt_dec x n); trivial; lia. - Unshelve. - lia. - Qed. - - Lemma scalarMult_mult {T} a b grad : scalarMult T a (scalarMult T b grad) = scalarMult T (a*b)%R grad. - Proof. - destruct T; simpl. - - lra. - - apply vec_eq_eq; intros ?; lra. - - do 2 (apply vec_eq_eq; intros ?); lra. - Qed. - - Corollary scalarMult_backprop_grad0 {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (s: SubVar) (grad_env :df_env) (grad : definition_function_types_interp T) : - let v := (s, DTfloat) in - vartlookup grad_env v <> None -> - df_eval_backprop_deriv σ df grad_env (scalarMult T 0%R grad) <> None -> - df_eval_backprop_delta σ df v grad_env (scalarMult T 0%R grad) = Some 0%R. - Proof. - simpl; intros. - (* This allows us to drop the (unneeded) assumption - df_eval_backprop_deriv σ df grad_env1 grad <> None - *) - replace (scalarMult T 0%R grad) with (scalarMult T 0%R (scalarMult T 0%R grad)). - - erewrite scalarMult_backprop_grad_scalar; try eassumption. - + simpl. - unfold df_eval_backprop_delta. - simpler2; simpl. - destruct (df_eval_backprop_deriv σ df grad_env (scalarMult T 0%R grad)); simpl; [| congruence]. - f_equal; lra. - + now rewrite scalarMult_mult, Rmult_0_l. - - now rewrite scalarMult_mult, Rmult_0_l. - Qed. - - Lemma scaleUnitVec_vec_plus_distr m n i (x y:Vector float m) c : - vec_eq c (dfti_gen_plus (T:=DTVector m) c c) -> - (scaleUnitVector n i (dfti_gen_plus (T:=DTVector m) x y) c) = - (dfti_gen_plus (T:=DTMatrix n m) (scaleUnitVector n i x c) (scaleUnitVector n i y c)). - Proof. - intros fa. - do 2 (apply functional_extensionality; intro). - unfold scaleUnitVector; simpl. - match_destr - ; try now specialize (fa x1); simpl in fa. - Qed. - - - Lemma df_eval_backprop_delta_by_unit_partmat_outer_bounded {n m} bound (pf:(bound<=n)%nat) - (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env (grad d:Matrix float n m): - - - fully_closed_over df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall (i : {n' : nat | n' < n}), - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (scaleUnitVector n i (grad i) (ConstVector m 0%R))) = Some (vsum (d i))) -> - - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (mergeVectorZero bound grad (ConstVector m 0%R))) = Some (vsum (vfirstn (vmap vsum d) bound pf)). - Proof. - intros closed lo fa. - induction bound. - - rewrite vfirstn0, vsum_nil. - - replace (mergeVectorZero 0 grad (ConstVector m 0%R)) with (scalarMult (DTMatrix n m) 0%R (mergeVectorZero 0 grad (ConstVector m 0%R))). - + apply scalarMult_backprop_grad0; simpl in *; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - + unfold scalarMult, ConstVector, mergeVectorZero. - do 2 (apply functional_extensionality; intros); simpl. - lra. - - assert (pf2:bound < n) by lia. - assert (pf3:bound <= n) by lia. - simpl. - generalize (mergeVectorSzero_mat (m:=m) _ pf2 grad ); intros HH; unfold Vector in HH. - simpl float in *. - rewrite HH; clear HH. - erewrite backprop_grad_sum; try eassumption. - specialize (IHbound pf3). - rewrite IHbound. - rewrite fa. - simpl. - f_equal. - generalize (vector_Sn_split (vfirstn (vmap vsum d) (S bound) pf)); intros eqq1. - apply vec_eq_eq in eqq1. - simpl in *. - rewrite eqq1. - erewrite vfirstn_vdrop_last. - erewrite vlast_vfirstn. - rewrite vsum_cons. - rewrite Rplus_comm. - f_equal. - rewrite vmap_nth. - eauto. - Qed. - - Lemma df_eval_backprop_delta_by_unit_partmat_inner_bounded {n m} bound (pf:(bound<=m)%nat) - (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env (grad d:Matrix float n m) i: - - - fully_closed_over df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall j : {n' : nat | n' < m}, - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - ((scaleUnitVector n i - (scaleUnitVector m j (grad i j) 0%R) (ConstVector m 0%R)))) = Some (d i j)) -> - - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (scaleUnitVector n i (mergeVectorZero bound (grad i) 0%R) (ConstVector m 0%R))) = Some (vsum (vfirstn (d i) bound pf)). - Proof. - intros closed lo fa. - induction bound. - - rewrite vfirstn0, vsum_nil. - replace (scaleUnitVector n i (mergeVectorZero 0 (grad i) 0%R) (ConstVector m 0%R)) with (scalarMult (DTMatrix n m) 0%R (mergeVectorZero 0 grad (ConstVector m 0%R))). - + apply scalarMult_backprop_grad0; simpl in *; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - + unfold scalarMult, ConstVector, mergeVectorZero, scaleUnitVector. - do 2 (apply functional_extensionality; intros); simpl. - match_destr; lra. - - assert (pf2:bound < m) by lia. - assert (pf3:bound <= m) by lia. - simpl. - generalize (mergeVectorSzero _ pf2 (grad i) ); intros HH; unfold Vector in HH. - simpl float in *. - rewrite HH; clear HH. - rewrite (scaleUnitVec_vec_plus_distr m n i) by (intros ?; unfold ConstVector; simpl; lra). - erewrite backprop_grad_sum; try eassumption. - specialize (IHbound pf3). - simpl in *. - rewrite IHbound. - simpl. - rewrite fa. - simpl. - f_equal. - generalize (vector_Sn_split (vfirstn (d i) (S bound) pf)); intros eqq1. - apply vec_eq_eq in eqq1. - simpl in *. - rewrite eqq1. - erewrite vfirstn_vdrop_last. - erewrite vlast_vfirstn. - rewrite vsum_cons. - rewrite Rplus_comm. - f_equal. - Qed. - - Lemma df_eval_backprop_delta_by_unit_partmat {n m} - (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env - (grad d:Matrix float n m): - fully_closed_over - df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) - σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - (forall (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) , - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (scalarMult (DTMatrix n m) (grad i j) - (UnitMatrix n m i j))) = - Some (d i j)) -> - df_eval_backprop_delta σ df (s, DTfloat) grad_env grad = - Some (msum d). - Proof. - intros closed lo fa. - replace (grad) with (mergeVectorZero n grad (ConstVector m 0%R)). - - erewrite df_eval_backprop_delta_by_unit_partmat_outer_bounded; try eassumption. - + now rewrite vfirstn_eq. - + intros i. - specialize (fa i). - simpl in fa. - replace (grad i) with (mergeVectorZero m (grad i) 0%R). - 2: { - unfold mergeVectorZero; simpl; apply functional_extensionality; intros; destruct x. - simpl. - match_destr; lia. - } - - erewrite (df_eval_backprop_delta_by_unit_partmat_inner_bounded m (le_refl m) σ df s grad_env) - ; try eassumption. - * now rewrite vfirstn_eq. - * intros j. - specialize (fa j); simpl in *. - rewrite <- fa. - f_equal. - do 2 (apply functional_extensionality; intros). - unfold scaleUnitVector, ConstVector, UnitMatrix; simpl. - repeat match_destr; lra. - - apply functional_extensionality; unfold mergeVectorZero; intros [x pff]; simpl. - destruct (lt_dec x n); trivial; lia. - Unshelve. - lia. - Qed. - - Lemma df_eval_backprop_delta_by_unit_parts_mat {n m} - (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env - (d:Matrix float n m): - fully_closed_over - df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) - σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) , - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - - (UnitMatrix n m i j)) = Some (d i j)) -> - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (ConstMatrix n m 1%R)) = Some (msum d). - Proof. - intros. - apply df_eval_backprop_delta_by_unit_partmat; trivial. - intros. - rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env); trivial. - unfold lift. - rewrite H1. - f_equal. - unfold scalarMult, ConstMatrix; simpl; lra. - apply backprop_deriv_fully_closed_not_none; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - Qed. - - Corollary scalarMult_backprop_list_env_iter_grad0 {T} (σ:df_env) (s: SubVar) (grad_env :df_env) (grad : definition_function_types_interp T) old n x l : - let v := (s, DTfloat) in - (forall j : {n' : nat | n' < n}, - fully_closed_over (x j) - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ)) -> - vartlookup grad_env v = Some old -> - lift (fun e => subvar v e old) - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv (Ann:=UnitAnn) σ (x i) env (scalarMult T 0%R grad)) (Some grad_env) l) = Some 0%R. - Proof. - simpl; intros. - revert grad_env old H0. - induction l; simpl; intros. - - unfold subvar; simpl. - rewrite H0; f_equal. - lra. - - unfold lift in *. - case_eq (df_eval_backprop_deriv σ (x a) grad_env (scalarMult T 0%R grad)) - ; intros. - + apply IHl. - * generalize (scalarMult_backprop_grad0 σ (x a) s grad_env grad); intros HH. - simpl in HH. - cut_to HH; try congruence. - unfold df_eval_backprop_delta in HH. - rewrite H0 in HH. - rewrite H1 in HH. - simpl in HH. - invcs HH. - unfold subvar in H3; simpl in H3. - simpler2. - rewrite eqq in eqq0; invcs eqq0; subst. - f_equal. - lra. - + eelim backprop_deriv_fully_closed_not_none; eauto. - Qed. - - Lemma list_env_iter_gen_delta0 {n} (s: SubVar) (init_env : df_env) - (f: {n' : nat | n' < n} -> df_env -> option df_env) (l : list {n' : nat | n' < n}): - let v := (s,DTfloat) in - vartlookup init_env v <> None -> - (forall (i : {n' : nat | n' < n}) (env : df_env), - (f i env) <> None /\ - (vartlookup env v <> None -> - match f i env with - | Some xenv => vartlookup xenv v <> None - | _ => True - end ) /\ - (In i l -> vartlookup env v <> None -> - lift2 (fun e val => subvar v e val) (f i env) - (vartlookup env v) = Some 0%R)) -> - lift2 (fun e old => subvar v e old) - (list_env_iter f (Some init_env) l) - (vartlookup init_env v) = Some 0%R. - Proof. - simpl; intros. - revert init_env H H0. - induction l. - - intros. - simpl; f_equal. - unfold subvar; simpl. - match_option; [|tauto]. - f_equal. - lra. - - intros. - assert (H0c := H0). - specialize (H0 a init_env). - destruct H0; destruct H1. - simpl. - case_eq (f a init_env); [intros|tauto]. - specialize (IHl d). - replace (vartlookup init_env (s, DTfloat)) with (vartlookup d (s, DTfloat)). - + apply IHl. - * specialize (H1 H). - now rewrite H3 in H1. - * intros. - specialize (H0c i env). - destruct H0c; destruct H5. - split; trivial. - split; trivial. - intros. - cut_to H6; trivial. - simpl; tauto. - + case_eq (vartlookup init_env (s, DTfloat)); [intros|tauto]. - rewrite H3, H4 in H2; simpl in H2. - cut_to H2; try tauto. - unfold lift2 in H2; simpl in H2. - inversion H2. - unfold subvar in H6; simpl in H6. - rewrite H3 in H1. - specialize (H1 H). - case_eq ( vartlookup d (s, DTfloat)); [intros|tauto]. - rewrite H5 in H6. - f_equal; lra. - congruence. - Qed. - - Lemma list_env_iter_gen_delta {n} (s: SubVar) (init_env : df_env) (old : float) - (f: {n' : nat | n' < n} -> df_env -> option df_env) (i0 : {n' : nat | n' < n}): - let v := (s,DTfloat) in - vartlookup init_env v = Some old -> - (forall (i : {n' : nat | n' < n}) (env : df_env), - (f i env) <> None /\ - (vartlookup env v <> None -> - match f i env with - | Some xenv => vartlookup xenv v <> None - | _ => True - end ) /\ - (forall (env2 : df_env), - match vartlookup env v, vartlookup env2 v, f i env, f i env2 with - | Some val1, Some val2, Some xenv, Some xenv2 => - subvar v xenv val1 = subvar v xenv2 val2 - | _, _, _, _ => True - end) /\ - ( i <> i0 -> vartlookup env v <> None -> - lift2 (fun e val => subvar v e val) (f i env) - (vartlookup env v) = Some 0%R)) -> - lift (fun e => subvar v e old) (f i0 init_env) = - lift (fun e => subvar v e old) - (list_env_iter f - (Some init_env) (bounded_seq0 n)). - Proof. - simpl; intros. - unfold bounded_seq0. - destruct (bounded_seq_break_at 0 n i0) as [b [c [eqq1 [fa1 fa2]]]]; [lia |]. - rewrite eqq1. - rewrite list_env_iter_app; simpl. - match_option. - - - assert (eqq': subvar (s, DTfloat) d old = 0%R). - + generalize (list_env_iter_gen_delta0 s init_env f b); intros. - simpl in H1. - cut_to H1; try congruence. - * unfold lift2 in H1. - rewrite H, eqq in H1. - now inversion H1. - * intros. - specialize (H0 i env). - destruct H0; destruct H2; destruct H3. - split; trivial. - split; trivial. - intros. - cut_to H4; trivial. - rewrite Forall_forall in fa1. - specialize (fa1 i H5). - intro eq1; rewrite eq1 in fa1; lia. - + generalize (vartlookup_list_env_iter s f b); intros vart. - * specialize (vart init_env d eqq). - assert (vartinit: vartlookup init_env (s, DTfloat) <> None) by congruence. - specialize (vart vartinit). - assert (f i0 d <> None) by apply H0. - case_eq (f i0 d); [intros | tauto]. - generalize (list_env_iter_gen_delta0 s d0 f c); simpl; intros. - cut_to H3; try congruence. - -- unfold lift at 2. - unfold lift2 in H3. - match_option. - ++ rewrite eqq0 in H3. - match_option_in H3. - rewrite (split_subvar d0 d1 old d2);trivial. - inversion H3. - rewrite H5. - unfold lift. - match_option. - ** f_equal. - case_eq (vartlookup d (s, DTfloat)); intros. - --- rewrite (split_subvar d d0 old d4); trivial. - rewrite eqq'. - specialize (H0 i0 init_env). - destruct H0; destruct H6; destruct H7. - specialize (H7 d). - rewrite H, H4, eqq3, H2 in H7. - rewrite H7; lra. - --- cut_to vart; [tauto|]. - intros. - specialize (H0 i env). - destruct H0; destruct H8. - specialize (H8 H7). - now rewrite H6 in H8. - ** specialize (H0 i0 init_env). - destruct H0; congruence. - ++ generalize (list_env_iter_total_fun f d0 c); intros. - cut_to H4; try congruence. - apply H0. - -- cut_to vart. - ++ specialize (H0 i0 d). - destruct H0; destruct H4. - rewrite H2 in H4. - apply H4; trivial. - ++ intros. - specialize (H0 i env). - destruct H0;destruct H6. - specialize (H6 H5). - now rewrite H4 in H6. - -- intros. - specialize (H0 i env). - split. - apply H0. - split. - apply H0. - intros. - assert (i <> i0). - ++ rewrite Forall_forall in fa2. - specialize (fa2 i H4). - intro eq1; rewrite eq1 in fa2; lia. - ++ now apply H0. - - generalize (list_env_iter_total_fun f init_env b); intros. - cut_to H1; trivial. - tauto. - intros. - apply H0. - Qed. - - Lemma list_env_iter_vec_delta {n} (σ:df_env) - (x:Vector (DefinedFunction UnitAnn DTfloat) n) (s: SubVar) grad_env - (i0 : {n' : nat | n' < n}) (old : float) : - let v := (s, DTfloat) in - vartlookup grad_env v = Some old -> - (forall (j: {n' : nat | n' < n}) , - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over (x j) vl) -> - lift (fun e => subvar v e old) - (df_eval_backprop_deriv σ (x i0) grad_env 1%R) = - lift (fun e => subvar v e old) - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i)) - (Some grad_env) (bounded_seq0 n)). - Proof. - simpl; intros. - generalize (list_env_iter_gen_delta - s grad_env old - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i)) - i0). - simpl; intros. - specialize (H1 H). - rewrite <- H1. - - unfold UnitVector; simpl. - now destruct (equiv_dec (` i0) (` i0)); [|congruence]. - - clear H1. - intros. - split; [|split]. - + apply backprop_deriv_fully_closed_not_none; auto. - + intros. - match_option. - now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial. - + split. - * intros. - do 4 match_option. - generalize (backprop_indep_env σ (x i) s env env2 - (UnitVector n i0 i)); simpl; intros HH. - cut_to HH; trivial; try congruence. - unfold df_eval_backprop_delta in HH. - rewrite eqq,eqq0,eqq1,eqq2 in HH. - unfold lift in HH. - now inversion HH. - * intros. - unfold UnitMatrix; simpl. - destruct (equiv_dec (` i) (` i0)). - elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel. - unfold lift2. - generalize (scalarMult_backprop_grad0 σ (x i) s env 0%R); simpl; intros. - unfold df_eval_backprop_delta in H3. - specialize (H3 H2). - replace (0 * 0)%R with 0%R in H3 by lra. - match_option. - match_option; [|tauto]. - -- rewrite eqq0 in H3. - replace (UnitVector n i0 i) with Fzero in eqq; simpl in eqq. - ++ rewrite eqq in H3; unfold lift in H3. - apply H3; congruence. - ++ unfold UnitVector; simpl. - destruct (equiv_dec (` i) (` i0)); [|trivial]. - elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel. - -- assert (df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i) <> None). - now apply backprop_deriv_fully_closed_not_none. - tauto. - Qed. - - - Lemma list_env_iter_backprop_indep_env_vec {m} (σ:df_env) - (vecdf: Vector (DefinedFunction UnitAnn DTfloat) m) - (s:SubVar) (env env2:df_env) (grad: Vector float m) - (old1 old2: float) : - let v := (s, DTfloat) in - vartlookup env v = Some old1 -> - vartlookup env2 v = Some old2 -> - (let vl := map (fun ve => projT1 ve) σ in - forall j : {m' : nat | m' < m}, - fully_closed_over (vecdf j) vl) -> - lift (fun e => subvar (s, DTfloat) e old1) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (vecdf j) env0 (grad j)) - (Some env) (bounded_seq0 m)) = - lift (fun e => subvar (s, DTfloat) e old2) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (vecdf j) env0 (grad j)) - (Some env2) (bounded_seq0 m)). - Proof. - intros. - subst v. - revert old1 old2 env env2 H H0. - induction (bounded_seq0 m). - - intros. - simpl. - unfold subvar; simpl. - rewrite H, H0. - f_equal; lra. - - intros. - simpl. - case_eq (df_eval_backprop_deriv σ (vecdf a) env (grad a)); intros. - case_eq (df_eval_backprop_deriv σ (vecdf a) env2 (grad a)); intros. - case_eq (vartlookup d (s, DTfloat)); intros. - case_eq (vartlookup d0 (s, DTfloat)); intros. - + specialize (IHl d1 d2 d d0 H4 H5). - unfold lift. - do 2 match_option. - * rewrite eqq, eqq0 in IHl. - unfold lift in IHl. - inversion IHl. - f_equal. - rewrite (split_subvar d d3 old1 d1); trivial. - rewrite (split_subvar d0 d4 old2 d2); trivial. - rewrite H7. - generalize (backprop_indep_env - σ (vecdf a) s - env env2 (grad a)); simpl; intros. - cut_to H6; trivial; try congruence. - unfold df_eval_backprop_delta in H6. - rewrite H, H0, H2, H3 in H6. - unfold lift in H6. - inversion H6. - rewrite H9; lra. - * generalize (list_env_iter_total_fun - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (vecdf j) env0 (grad j)) - d0 l); intros. - cut_to H6; [tauto|]. - intros. - apply backprop_deriv_fully_closed_not_none; auto. - * generalize (list_env_iter_total_fun - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (vecdf j) env0 (grad j)) - d l); intros. - cut_to H6; [tauto|]. - intros. - apply backprop_deriv_fully_closed_not_none; auto. - + apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H3 - ;trivial; congruence. - + apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H2 - ;trivial; congruence. - + assert (df_eval_backprop_deriv σ (vecdf a) env2 (grad a) <> None) by - (apply backprop_deriv_fully_closed_not_none; auto); tauto. - + assert (df_eval_backprop_deriv σ (vecdf a) env (grad a) <> None) by - (apply backprop_deriv_fully_closed_not_none; auto); tauto. - Qed. - - Lemma list_env_iter_mat_delta {n m} (σ:df_env) - (x:Matrix (DefinedFunction UnitAnn DTfloat) n m) (s: SubVar) grad_env - (i0 : {n' : nat | n' < n}) - (j0 : {m' : nat | m' < m}) (old : float) : - let v := (s, DTfloat) in - vartlookup grad_env v = Some old -> - (forall (i: {n' : nat | n' < n}) (j: {m' : nat | m' < m}) , - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over (x i j) vl) -> - lift (fun e => subvar v e old) - (df_eval_backprop_deriv σ (x i0 j0) grad_env 1%R) = - lift (fun e => subvar v e old) - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv - σ (x i j) env0 - (UnitMatrix n m i0 j0 i j)) (Some env) - (bounded_seq0 m)) (Some grad_env) (bounded_seq0 n)). - Proof. - simpl; intros. - generalize (list_env_iter_gen_delta - s grad_env old - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv - σ (x i j) env0 - (UnitMatrix n m i0 j0 i j)) (Some env) - (bounded_seq0 m)) i0); simpl. - intros. - specialize (H1 H). - rewrite <- H1. - - clear H1. - generalize (list_env_iter_gen_delta - s grad_env old - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i0 j) env0 (UnitMatrix n m i0 j0 i0 j)) - j0); simpl. - intros. - specialize (H1 H). - rewrite <- H1. - + unfold UnitMatrix; simpl. - destruct (equiv_dec (` i0) (` i0)); [|congruence]. - now destruct (equiv_dec (` j0) (` j0)); [|congruence]. - + clear H1. - intros. - split; [|split]. - * apply backprop_deriv_fully_closed_not_none; auto. - * intros. - match_option. - now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial. - * split. - -- intros. - do 4 match_option. - generalize (backprop_indep_env σ (x i0 i) s env env2 - (UnitMatrix n m i0 j0 i0 i)); simpl; intros HH. - cut_to HH; trivial; try congruence. - unfold df_eval_backprop_delta in HH. - rewrite eqq,eqq0,eqq1,eqq2 in HH. - unfold lift in HH. - now inversion HH. - -- intros. - unfold UnitMatrix; simpl. - destruct (equiv_dec (` i0) (` i0)); [|congruence]. - destruct (equiv_dec (` i) (` j0)). - elim H1; destruct i; destruct j0; simpl in *; red in e0; subst; apply index_pf_irrel. - unfold lift2. - generalize (scalarMult_backprop_grad0 σ (x i0 i) s env 0%R); simpl; intros. - unfold lift2. - replace (0 * 0)%R with 0%R in H3 by lra. - unfold df_eval_backprop_delta in H3. - match_option. - match_option. - rewrite eqq0, eqq in H3; unfold lift in H3. - cut_to H3; congruence. - rewrite eqq0 in H3. - apply H3; trivial. - tauto. - congruence. - assert (df_eval_backprop_deriv σ (x i0 i) env 0%R <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - - split. - generalize (list_env_iter_total_fun - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j)) - env (bounded_seq0 m)); intros. - apply H2. - intros. - apply backprop_deriv_fully_closed_not_none; trivial. - split. - + intros. - match_option. - apply (vartlookup_list_env_iter - s (fun (j : {m' : nat | m' < m}) - (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j)) - (bounded_seq0 m) env d); trivial. - intros. - apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H3; trivial. - + split. - * intros. - do 4 match_option. - generalize (list_env_iter_backprop_indep_env_vec - σ (x i) s env env2 - (UnitMatrix n m i0 j0 i) - d d0); simpl; intros. - specialize (H2 eqq eqq0). - specialize (H2 (H0 i)). - rewrite eqq1, eqq2 in H2. - unfold lift in H2. - now inversion H2. - * intros. - replace (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j)) - with - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 - (scalarMult DTfloat 0%R 0%R)). - case_eq (vartlookup env (s, DTfloat)); [intros | tauto]. - apply scalarMult_backprop_list_env_iter_grad0; trivial. - apply functional_extensionality; intros. - apply functional_extensionality; intros. - f_equal. - unfold scalarMult, UnitMatrix; simpl. - destruct (equiv_dec (` i) (` i0)). - red in e. - elim H2; destruct i; destruct i0; simpl in *; subst. - apply index_pf_irrel. - lra. - Qed. - - Lemma list_env_iter_matvec_delta {m n} (σ:df_env) - (df2:DefinedFunction UnitAnn (DTVector m)) (s: SubVar) grad_env - (i0 : {n' : nat | n' < n}) - (j0 : {m' : nat | m' < m}) (old : float) : - vartlookup grad_env (s, DTfloat) = Some old -> - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df2 vl -> - (lift (fun e => subvar (s, DTfloat) e old) - (df_eval_backprop_deriv - σ df2 grad_env - (UnitVector m j0)) = - (lift (fun e => subvar (s, DTfloat) e old) - (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ df2 env - ((UnitMatrix n m i0 j0) i)) - (Some grad_env) - (bounded_seq0 n)))). - Proof. - simpl; intros. - generalize (list_env_iter_gen_delta - s grad_env old - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ df2 env - ((UnitMatrix n m i0 j0) i)) - i0). - simpl; intros. - specialize (H1 H). - rewrite <- H1. - - unfold UnitMatrix, UnitVector; simpl. - now destruct (equiv_dec (` i0) (` i0)); [|congruence]. - - clear H1. - intros. - split; [|split]. - + apply backprop_deriv_fully_closed_not_none; auto. - + intros. - match_option. - now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial. - + split. - * intros. - do 4 match_option. - generalize (backprop_indep_env σ df2 s env env2 - (UnitMatrix n m i0 j0 i)); simpl; intros HH. - cut_to HH; trivial; try congruence. - unfold df_eval_backprop_delta in HH. - rewrite eqq,eqq0,eqq1,eqq2 in HH. - unfold lift in HH. - now inversion HH. - * intros. - unfold UnitMatrix; simpl. - destruct (equiv_dec (` i) (` i0)). - elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel. - unfold lift2. - generalize (scalarMult_backprop_grad0 σ df2 s env (ConstVector m 0%R)); simpl; intros. - unfold df_eval_backprop_delta in H3. - specialize (H3 H2). - match_option. - match_option; [|tauto]. - -- rewrite eqq0 in H3. - replace (fun i : {n' : nat | n' < m} => (0 * ConstVector m 0 i)%R) with - (fun i : {n' : nat | n' < m} => 0%R) in H3. - ++ rewrite eqq in H3; unfold lift in H3. - apply H3; congruence. - ++ apply functional_extensionality; intros. - unfold ConstVector; simpl. - lra. - -- assert (df_eval_backprop_deriv σ df2 env (fun _ => 0%R) <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - Qed. - - Lemma vmap_eta {A B} {n} (f:A->B) (d:Vector A n) : vmap f d = vmap (fun x => f x) d. - Proof. - now apply vmap_ext. - Qed. - - Lemma vsum_eta {n} (d:Vector float n) : vsum d = vsum (fun i => d i). - Proof. - now apply vsum_ext. - Qed. - - Lemma msum_unitvector m n x d1 d0 : - msum - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (UnitVector m x i * d1 i j * d0 j)%R) = - vsum - (fun j : {n' : nat | n' < n} => - (d1 x j * d0 j)%R). - Proof. - unfold msum. - - transitivity ( - vsum - (fun i : {n' : nat | n' < m} => - vsum - (fun (j : {m' : nat | m' < n}) => (UnitVector m x i * d1 i j * d0 j)%R))). - { apply vsum_ext; intros ?. - now rewrite vmap_nth. - } - transitivity ( - vsum - (fun i : {n' : nat | n' < m} => - ((UnitVector m x i * vsum - (fun (j : {m' : nat | m' < n}%nat) => d1 i j * d0 j))%R))). - { - apply vsum_ext; intros ?. - rewrite vsum_mult. - apply vsum_ext; intros ?. - lra. - } - transitivity ( - vsum - (fun i : {n' : nat | n' < m} => - (vsum (fun j : {m' : nat | (m' < n)%nat} => d1 i j * d0 j) * UnitVector m x i)%R)). - { - apply vsum_ext; intros ?; lra. - } - now rewrite vsum_unitvector. - Qed. - - Lemma vsum_as_sum {n} (v:Vector float n) : vsum v = fold_right Rplus R0 (vector_to_list v). - Proof. - rewrite vsum_alt_eq. - unfold vector_to_list, vector_fold_right. - induction n. - - rewrite vector_fold_right_dep_0; trivial. - - repeat rewrite vector_fold_right_dep_Sn. - now rewrite IHn. - Qed. - - Lemma msum_as_sum {m n} (mat:Matrix float m n) : msum mat = fold_right Rplus R0 (matrix_to_list mat). - Proof. - unfold msum, matrix_to_list, matrix_to_list_list. - transitivity - (vsum (fun i => fold_right Rplus R0 (vector_to_list (mat i)))). - { apply vsum_ext; intros [??]. - now rewrite vmap_nth, vsum_as_sum. - } - rewrite vsum_as_sum. - rewrite fold_right_plus_concat. - rewrite map_vector_to_list_vmap. - f_equal. - apply vector_to_list_ext. - intros [??]. - now rewrite vmap_nth. - Qed. - - Lemma transpose_perm_bounded {m n : nat} (mat : Matrix float m n) bound_m pf_m bound_n pf_n: - Permutation - (concat - (vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] - (fun i : {n' : nat | n' < m} => - vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] (mat i) bound_n pf_n) bound_m pf_m)) - (concat - (vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] - (fun i : {n' : nat | n' < n} => - vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] (transpose mat i) bound_m pf_m) bound_n pf_n)). - Proof. - Hint Constructors Permutation : fml. - revert bound_n pf_n. - induction bound_m; intros; simpl. - - induction bound_n; simpl; trivial. - - rewrite IHbound_m. - clear IHbound_m. - induction bound_n; simpl; trivial. - rewrite <- IHbound_n. - clear IHbound_n. - apply Permutation_cons; trivial. - unfold transpose; simpl. - repeat rewrite <- app_ass. - apply Permutation_app; trivial. - apply Permutation_app_comm. - Qed. - - Lemma transpose_perm {m n} (mat : Matrix float m n) : - Permutation (matrix_to_list mat) (matrix_to_list (transpose mat)). - Proof. - apply transpose_perm_bounded. - Qed. - - Lemma msum_transpose {m n} (mat : Matrix float m n) : - msum mat = msum (transpose mat). - Proof. - repeat rewrite msum_as_sum. - apply fold_right_perm; intros; try lra. - apply transpose_perm. - Qed. - - Ltac match_nested_case := - match goal with - | [|- context[match match ?x with _ => _ end with _ => _ end]] => - let eqq := fresh "eqq" in - case_eq x - ; [intros ? eqq | intros eqq] - end. - - Ltac match_nested_case_in H := - match H with - | context[match match ?x with _ => _ end with _ => _ end] => - let eqq := fresh "eqq" in - case_eq x - ; [intros ? eqq | intros eqq] - ; rewrite eqq in H - end. - - Theorem df_eval_deriv_genvar_same (σ:df_env) (df:DefinedFunction UnitAnn DTfloat) (v:SubVar) : - let vl := map (fun ve => projT1 ve) σ in - is_scalar_function df -> - fully_closed_over df vl -> - let forward := df_eval_deriv_gen_top σ df (v, DTfloat) in - lift transpose_lifted_type forward = df_eval σ (df_deriv df (v,DTfloat)). - Proof. - simpl. - intros is_scalar. - generalize is_scalar. - pattern df. - revert df is_scalar. - DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl; trivial; intros - ; try - ( cut_to H; trivial - ;rewrite <- H - ;assert (df_eval σ e <> None) by (apply eval_fully_closed_not_none; trivial) - ;unfold lift, df_eval_deriv_gen_top; simpl - ;match_nested_case; [|tauto] - ;now match_nested_case). - - - Case "Var"%string. - match_nested_case. - + red in e. - inversion e; subst. - now refl_simpler; simpl. - + now intros. - - Case "Plus"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_option_in H. - + match_option_in H0. - * rewrite <- H. - now rewrite <- H0. - * assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - - Case "Minus"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_option_in H. - + match_option_in H0. - * rewrite <- H. - now rewrite <- H0. - * assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - - Case "Times"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_nested_case; trivial. - match_option_in H. - + match_nested_case. - * match_option_in H0. - -- rewrite <- H. - now rewrite <- H0. - -- assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - * assert (df_eval σ r <> None) - ; [apply eval_fully_closed_not_none; trivial | tauto]. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - - Case "Divide"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_nested_case; trivial. - match_option_in H. - + match_nested_case. - * match_option_in H0. - -- rewrite <- H. - now rewrite <- H0. - -- assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - * assert (df_eval σ r <> None) - ; [apply eval_fully_closed_not_none; trivial | tauto]. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - + assert (df_eval σ l <> None) - ; [apply eval_fully_closed_not_none; trivial | tauto]. - - Case "Sign"%string. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_nested_case; [trivial|]. - assert (df_eval_deriv_genvar σ e [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - - Case "PSign"%string. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_nested_case; [trivial|]. - assert (df_eval_deriv_genvar σ e [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - - Case "Max"%string. - destruct is_scalar; destruct H1. - cut_to H; trivial. - cut_to H0; trivial. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - assert (df_eval σ l <> None) by (apply eval_fully_closed_not_none; trivial). - assert (df_eval σ r <> None) by (apply eval_fully_closed_not_none; trivial). - match_nested_case; [|tauto]. - match_nested_case; [|tauto]. - match_option_in H. - match_option_in H0. - + rewrite <- H. - rewrite <- H0. - unfold pos_sign; simpl. - case_eq (Rle_dec d d0); intros; f_equal. - destruct (Rge_dec (d0 - d) 0); lra. - destruct (Rge_dec (d0 - d) 0); lra. - + assert (df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - Qed. - - Lemma yay {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (hs:has_scalar_functions df) (s: SubVar) grad_env : - let v := (s, DTfloat) in - vartlookup grad_env v <> None -> - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - let forward := df_eval_deriv_gen_top σ df v in - let backward := df_eval_backward_gen_top σ df v grad_env in - lift transpose_lifted_type forward = backward. - Proof. - simpl. - intros vin closed. - revert grad_env vin closed. - unfold df_eval_deriv_gen_top, df_eval_backward_gen_top. - pattern T, df. - revert T df hs. - DefinedFunction_cases (apply DefinedFunction_ind_unit_has_scalar_functions) Case - ; simpl; intros. - - Case "Number"%string. - unfold subvar; simpl. - match_destr; [ | tauto]. - f_equal; lra. - - Case "Constant"%string. - unfold subvar; simpl. - match_destr; simpl. - + match_destr; [ | tauto]. - f_equal; lra. - + match_destr; [ | tauto]. - erewrite vectoro_to_ovector_forall_some_b_strong - ; simpl; trivial; intros. - unfold ConstVector. - f_equal; lra. - + match_destr; [ | tauto]. - unfold matrixo_to_omatrix. - repeat (erewrite vectoro_to_ovector_forall_some_b_strong - ; simpl; trivial; intros). - unfold ConstMatrix. - f_equal; lra. - - Case "DVector"%string. - unfold lift, two_vector_env_iter_alt. - case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto]. - match_option. - + specialize (apply vectoro_to_ovector_forall_some_f eqq); intros. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - unfold snd in H. - specialize (H i grad_env); simpl in H. - rewrite vforall_forall in closed. - assert (closedb := closed). - specialize (closed i). - specialize (H vin closed); simpl in H. - rewrite H0 in H. - specialize (H1 i); simpl in H1. - rewrite H1 in H; simpl in H. - unfold lift in H. - match_option_in H. - specialize (apply vectoro_to_ovector_forall_some_f eqq); intros. - specialize (H2 i); simpl in H2. - destruct i; simpl. - unfold UnitVector; simpl. - generalize (list_env_iter_total_fun - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env - (if equiv_dec (` i) x0 then 1%R else 0%R)) - grad_env (bounded_seq0 n)); intros. - cut_to H3. - match_option; [|tauto]. - * rewrite H. - generalize (list_env_iter_vec_delta σ x s grad_env - (exist (fun n' : nat => n' < n) x0 l) d). - intros. - simpl in H4. - specialize (H4 H0 closedb). - rewrite eqq0 in H4. - unfold UnitVector in H4; simpl in H4. - rewrite eqq1 in H4. - unfold lift in H4. - symmetry; trivial. - * intros. - apply backprop_deriv_fully_closed_not_none. - apply closedb. - + specialize (vectoro_to_ovector_exists_None eqq); intros. - destruct H1. - generalize (eval_deriv_genvar_fully_closed_not_none σ (x x0) [mk_env_entry (s, DTfloat) 1%R]); intros. - rewrite vforall_forall in closed. - now specialize (H1 (closed x0)). - - Case "DMatrix"%string. - unfold lift, matrixo_to_omatrix. - case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto]. - match_option. - + specialize (apply vectoro_to_ovector_forall_some_f eqq); intros. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H1 i); simpl in H1. - specialize (apply vectoro_to_ovector_forall_some_f H1); intros. - specialize (H2 i0); simpl in H2. - unfold two_matrix_env_iter_alt. - specialize (H i i0 grad_env); simpl in H. - rewrite vforall_forall in closed. - assert (closedb := closed). - specialize (closed i). - rewrite vforall_forall in closed. - specialize (closed i0). - specialize (H vin closed); simpl in H. - rewrite H2 in H; simpl in H. - rewrite H0 in H; simpl in H. - unfold lift in H. - match_option_in H. - rewrite H. - destruct i. - destruct i0. - unfold lift; simpl. - generalize (list_env_iter_total_fun - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv - σ (x i j) env0 - (UnitMatrix n m (exist (fun n' : nat => n' < n) x0 l) - (exist (fun n' : nat => n' < m) x1 l0) i j)) - (Some env) - (bounded_seq0 m)) - grad_env (bounded_seq0 n)); intros. - cut_to H3. - match_option; [|tauto]. - * generalize (list_env_iter_mat_delta σ x s grad_env - (exist (fun n' : nat => n' < n) x0 l) - (exist (fun n' : nat => n' < m) x1 l0) d). - intros. - simpl in H4. - specialize (H4 H0). - cut_to H4. - rewrite eqq0 in H4. - rewrite eqq1 in H4. - unfold lift in H4. - symmetry; trivial. - intros. - specialize (closedb i). - rewrite vforall_forall in closedb. - apply closedb. - * intros. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (closedb a). - rewrite vforall_forall in closedb. - apply closedb. - + specialize (vectoro_to_ovector_exists_None eqq); intros. - destruct H1. - specialize (vectoro_to_ovector_exists_None e); intros. - destruct H1. - symmetry. - generalize (eval_deriv_genvar_fully_closed_not_none σ (x x0 x1) [mk_env_entry (s, DTfloat) 1%R]); intros. - rewrite vforall_forall in closed. - specialize (closed x0). - rewrite vforall_forall in closed. - now specialize (H1 (closed x1)). - - Case "Var"%string. - unfold equiv_dec, vart_eqdec; simpl. - destruct (vart_dec v (s, DTfloat)). - + destruct v. - inversion e. - subst; simpl. - refl_simpler; simpl. - case_eq (vartlookup grad_env (s, DTfloat)); [intros |tauto]. - simpl; f_equal. - symmetry. - now apply subvar_addvar_scalar_eq. - + case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto]. - destruct v. - unfold snd. - destruct d0. - * simpl. - unfold lift. - destruct (vartlookup grad_env (s0, DTfloat)). - -- f_equal; simpl; symmetry. - now apply subvar_addvar_scalar_neq. - -- f_equal; symmetry. - unfold subvar; simpl. - rewrite H; lra. - * simpl; symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros; unfold lift. - destruct (vartlookup grad_env (s0, DTVector n0)). - -- f_equal; simpl. - unfold ConstVector. - now apply subvar_addvar_scalar_neq. - -- f_equal; unfold ConstVector. - unfold subvar; simpl. - rewrite H; lra. - * simpl; symmetry. - unfold matrixo_to_omatrix. - apply vectoro_to_ovector_forall_some_b_strong; intros. - apply vectoro_to_ovector_forall_some_b_strong; intros. - unfold lift. - destruct (vartlookup grad_env (s0, DTMatrix m n0)). - -- f_equal; unfold ConstMatrix. - now apply subvar_addvar_scalar_neq. - -- f_equal; unfold ConstMatrix. - unfold subvar; simpl. - rewrite H; lra. - - Case "Plus"%string. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H. - case_eq (df_eval_backprop_deriv σ l grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - destruct closed. - specialize (H H1). - invcs H. - { specialize (H0 d1). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))| apply H2]. - match_option - ; rewrite eqq2 in H0 - ; simpl in *. - - match_option_in H0. - unfold lift in H0. - match_option_in H0. - invcs H0. - simpl. - f_equal. - unfold subvar; simpl. - rewrite eqq3. - match_option; lra. - - match_option_in H0; simpl. - + unfold lift in *. - match_option_in H0. - + elim (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat) vin eqq3). - } - + destruct closed. - specialize (H H1). - congruence. - + destruct closed. - specialize (H H1). - congruence. - + destruct closed. - specialize (H H1). - case_eq (df_eval_backprop_deriv σ l grad_env 1%R) - ; [intros ? eqq2 | intros eqq2]. - rewrite eqq2 in H. - simpl in *. - congruence. - rewrite eqq2 in H. - now apply H. - - Case "Minus"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H. - case_eq (df_eval_backprop_deriv σ l grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - invcs H. - { specialize (H0 d1). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - | apply H2]. - match_option - ; rewrite eqq2 in H0 - ; simpl in *. - - match_option_in H0. - unfold lift in H0. - match_option_in H0. - invcs H0. - simpl. - f_equal. - unfold subvar; simpl. - rewrite eqq3. - match_option. - + case_eq (df_eval_backprop_deriv σ r d1 (- (1))%R); intros. - * unfold lift. - f_equal. - match_option. - -- generalize (scalarMult_backprop_grad_scalar σ r s d1 d1 1%R (-1)%R) - ; intros; simpl in H0. - cut_to H0. - ++ unfold df_eval_backprop_delta in H0. - rewrite eqq3 in H0; unfold lift in H0; simpl in H0. - replace (-1 * 1)%R with (- (1))%R in H0 by lra. - rewrite H, eqq4 in H0; inversion H0. - unfold subvar in H0; simpl in H0. - rewrite eqq6, eqq5 in H0. - inversion H0. - lra. - ++ congruence. - ++ congruence. - ++ simpl; replace (-1 * 1)%R with (- (1))%R by lra; congruence. - ++ congruence. - -- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (s,DTfloat)) - ; intros. - cut_to H0; congruence. - * generalize (backprop_deriv_fully_closed_not_none σ r d1 (- (1))%R); intros. - destruct H0; trivial. - + generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s,DTfloat)) - ; intros. - cut_to H; congruence. - - match_option_in H0. - + unfold lift. - unfold lift in H0. - match_option_in H0. - match_option. - specialize (df_eval_backprop_deriv_preserves_lookup_not_none eqq5). - intros. - specialize (H (s, DTfloat)). - generalize (backprop_deriv_fully_closed_not_none σ r d1 (1%R)); intros. - now destruct H3. - + generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s,DTfloat)) - ; intros. - cut_to H; congruence. - } - + unfold lift in H; simpl in H. - match_option_in H. - - Case "Times"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - case_eq (df_eval_backprop_deriv σ l grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ l s grad_env grad_env 1%R d1) - ; intros; simpl in H7. - unfold df_eval_backprop_delta in H7. - rewrite eqq1 in H7; simpl in H7. - specialize (H7 vin vin). - rewrite eqq0 in H7. - generalize (backprop_deriv_fully_closed_not_none σ l grad_env (d1 * 1)%R); intros. - specialize (H8 H1); specialize (H7 H8). - cut_to H7; try discriminate. - invcs H. - { specialize (H0 d3). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - | apply H2]. - unfold lift in H7; simpl in H7. - match_option_in H7. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - specialize (H vin). - generalize (backprop_deriv_fully_closed_not_none σ r d3 1%R); intros. - specialize (H9 H2). - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H10 in H0. - unfold lift; simpl. - unfold lift in H0. - match_option_in H0. - - generalize (scalarMult_backprop_grad_scalar σ r s d2 d3 1%R d) - ; intros. - unfold df_eval_backprop_delta in H11; simpl in H11. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros. - specialize (H12 vin). - case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence]. - specialize (H11 H12 H). - rewrite H10, H13 in H11. - generalize (backprop_deriv_fully_closed_not_none σ r d2 (d * 1)%R); intros. - specialize (H14 H2); specialize (H11 H14 H9). - unfold lift in H11. - match_option_in H11; [|congruence]; f_equal. - rewrite (split_subvar d2 d7 d0 d6); trivial. - match_option_in H0; invcs H0. - rewrite eqq5 in H11; invcs H11; invcs H7. - lra. - - now match_option_in H0. - } - unfold lift in H. - match_case_in H; intros. - rewrite H7 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ l grad_env 1%R); intros. - now specialize (H8 H1). - - Case "Divide"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - case_eq (df_eval_backprop_deriv σ l grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ l s grad_env grad_env 1%R (1 / d1)%R) - ; intros; simpl in H7. - unfold df_eval_backprop_delta in H7. - rewrite eqq1 in H7; simpl in H7. - specialize (H7 vin vin). - rewrite eqq0 in H7. - generalize (backprop_deriv_fully_closed_not_none σ l grad_env (1 / d1 * 1)%R); intros. - specialize (H8 H1) ; specialize (H7 H8). - cut_to H7; try discriminate. - replace (1 / d1 * 1)%R with (1/d1)%R in H7 by lra. - invcs H. - { specialize (H0 d3). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - | apply H2]. - unfold lift in H7; simpl in H7. - match_option_in H7. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - specialize (H vin). - generalize (backprop_deriv_fully_closed_not_none σ r d3 1%R); intros. - specialize (H9 H2). - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H10 in H0. - unfold lift; simpl. - unfold lift in H0. - match_option_in H0. - - generalize (scalarMult_backprop_grad_scalar σ r s d2 d3 1%R (- d / (d1 * d1))%R) - ; intros; simpl in H11. - unfold df_eval_backprop_delta in H11; simpl in H11. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros. - specialize (H12 vin). - case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence]. - specialize (H11 H12 H). - rewrite H10, H13 in H11. - generalize (backprop_deriv_fully_closed_not_none σ r d2 (- d / (d1 * d1) * 1)%R); intros. - specialize (H14 H2); specialize (H11 H14 H9). - unfold lift in H11. - match_option_in H11; [|congruence]; f_equal. - rewrite (split_subvar d2 d7 d0 d6); trivial. - match_option_in H0; invcs H0. - rewrite eqq5 in H11; invcs H11; invcs H7. - lra. - - now match_option_in H0. - } - unfold lift in H. - match_case_in H; intros. - rewrite H7 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ l grad_env 1%R); intros. - now specialize (H8 H1). - - Case "Square"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (2 * d)%R) - ; intros; simpl in H2. - unfold df_eval_backprop_delta in H2. - rewrite eqq1 in H2; simpl in H2. - specialize (H2 vin vin). - rewrite eqq0 in H2. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (2 * d * 1)%R); intros. - specialize (H3 closed); specialize (H2 H3). - cut_to H2; try discriminate. - invcs H. - unfold lift in H2; match_option_in H2. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; now rewrite H2. - + unfold lift in H. - match_case_in H; intros. - rewrite H2 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H3 closed). - - Case "Exp"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (exp d)%R) - ; intros; simpl in H2. - unfold df_eval_backprop_delta in H2. - rewrite eqq1 in H2; simpl in H2. - specialize (H2 vin vin). - rewrite eqq0 in H2. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (exp d * 1)%R); intros. - specialize (H3 closed); specialize (H2 H3). - cut_to H2; try discriminate. - invcs H. - replace (1 * exp d)%R with (exp d * 1)%R by lra. - unfold lift in H2; match_option_in H2. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; rewrite H2. - f_equal; lra. - + unfold lift in H. - match_case_in H; intros. - rewrite H2 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H3 closed). - - Case "Log"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (1 / d)%R) - ; intros; simpl in H2. - unfold df_eval_backprop_delta in H2. - rewrite eqq1 in H2; simpl in H2. - specialize (H2 vin vin). - rewrite eqq0 in H2. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (1 / d * 1)%R); intros. - specialize (H3 closed); specialize (H2 H3). - cut_to H2; try discriminate. - invcs H. - replace (1 / d)%R with (1 / d * 1)%R by lra. - unfold lift in H2; match_option_in H2. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; rewrite H2. - f_equal; lra. - + unfold lift in H. - match_case_in H; intros. - rewrite H2 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H3 closed). - - Case "Abs"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (sign d)%R) - ; intros; simpl in H2. - unfold df_eval_backprop_delta in H2. - rewrite eqq1 in H2; simpl in H2. - specialize (H2 vin vin). - rewrite eqq0 in H2. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (sign d * 1)%R); intros. - specialize (H3 closed); specialize (H2 H3). - cut_to H2; try discriminate. - invcs H. - replace (1 * (@sign floatish_R d))%R with ((@sign floatish_R d) * 1)%R by lra. - unfold lift in H2; match_option_in H2. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; rewrite H2. - f_equal; lra. - + unfold lift in H. - match_case_in H; intros. - rewrite H2 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H3 closed). - - Case "Sign"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (0)%R) - ; intros H1; simpl in H1. - unfold df_eval_backprop_delta in H1. - rewrite eqq1 in H1; simpl in H1. - specialize (H1 vin vin). - rewrite eqq0 in H1. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (0 * 1)%R); intros. - specialize (H2 closed); specialize (H1 H2). - cut_to H1; try discriminate. - invcs H. - replace (0)%R with (0 * 1)%R by lra. - unfold lift in H1; match_option_in H1. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; rewrite H1. - f_equal; lra. - + unfold lift in H. - match_case_in H; intros. - rewrite H1 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H2 closed). - - Case "PSign"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (0)%R) - ; intros H1; simpl in H1. - unfold df_eval_backprop_delta in H1. - rewrite eqq1 in H1; simpl in H1. - specialize (H1 vin vin). - rewrite eqq0 in H1. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (0 * 1)%R); intros. - specialize (H2 closed); specialize (H1 H2). - cut_to H1; try discriminate. - invcs H. - replace (0)%R with (0 * 1)%R by lra. - unfold lift in H1; match_option_in H1. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; rewrite H1. - f_equal; lra. - + unfold lift in H. - match_case_in H; intros. - rewrite H1 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H2 closed). - - Case "Max"%string. - destruct closed. - specialize (H grad_env vin H1). - specialize (H0 grad_env vin H2). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros|tauto]. - rewrite eqq0 in H. - rewrite eqq0 in H0. - destruct (Rle_dec d d1). - + apply H0. - + apply H. - - Case "VectorDot"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros | tauto]. - replace (fun rv : R => (rv * 1)%R) with id. - rewrite vmap_id; rewrite vmap_id. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto]. - symmetry in H. - generalize (backprop_deriv_fully_closed_not_none σ l grad_env d0); intros. - specialize (H7 H1). - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq1 in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H8. - match_option. - + match_option; [|tauto]. - generalize (backprop_deriv_fully_closed_not_none σ r d4 d); intros. - specialize (H9 H2). - unfold lift. - match_option; [|tauto]. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros. - specialize (H10 vin). - specialize (H0 d4 H10 H2). - rewrite eqq0 in H0. - match_option_in H0. - symmetry in H0. - unfold lift in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros. - simpl in H11. - f_equal. - unfold lift in H8. - rewrite (split_subvar d4 d5 d1 d6); trivial. - rewrite <- vsum_plus. - generalize (df_eval_backprop_delta_by_unit_partvec σ l s grad_env d0 - (fun i => (d2 i * d0 i)%R)); intros. - specialize (H12 H1 vin). - generalize (df_eval_backprop_delta_by_unit_partvec σ r s d4 d - (fun i => (d i * d3 i)%R)); intros. - specialize (H13 H2 H10). - cut_to H12. - cut_to H13. - * unfold df_eval_backprop_delta in H12. - rewrite eqq1 in H12. - unfold df_eval_backprop_delta in H13. - rewrite eqq4 in H13. - unfold lift in H12. - unfold lift in H13. - rewrite eqq2 in H12. - rewrite eqq3 in H13. - invcs H12; invcs H13. - rewrite H14, H15; lra. - * intros. - replace (@scaleUnitVector (@float floatish_R) n i (d i) (IZR Z0)) with - (scalarMult (DTVector n) (d i) (UnitVector n i)). - rewrite scalarMult_backprop_grad_scalar with (grad_env1 := d4) (grad_env2:=d4);trivial. - unfold df_eval_backprop_delta, lift. - rewrite eqq4. - specialize (H11 i). - destruct i; simpl in H11. - simpl_closed_backprop. - f_equal. - rewrite eqq5 in H11. - invcs H11. - rewrite H15. - now unfold scalarMult; simpl. - apply backprop_deriv_fully_closed_not_none; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - unfold scalarMult, UnitVector, scaleUnitVector; simpl. - apply functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)); lra. - * intros. - replace (@scaleUnitVector (@float floatish_R) n i (d0 i) (IZR Z0)) with - (scalarMult (DTVector n) (d0 i) (UnitVector n i)). - rewrite scalarMult_backprop_grad_scalar with (grad_env1 := grad_env) (grad_env2:=grad_env);trivial. - unfold df_eval_backprop_delta, lift. - rewrite eqq1. - specialize (H8 i). - destruct i; simpl in H8. - simpl_closed_backprop. - f_equal. - rewrite eqq5 in H8. - invcs H8. - rewrite H15. - now unfold scalarMult; simpl; lra. - apply backprop_deriv_fully_closed_not_none; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - unfold scalarMult, UnitVector, scaleUnitVector; simpl. - apply functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)); lra. - + generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros. - now specialize (H9 H2). - + generalize (eval_deriv_genvar_fully_closed_not_none σ l [mk_env_entry (s, DTfloat) 1%R]); intros. - now specialize (H8 H1). - + apply FunctionalExtensionality.functional_extensionality. - intros; unfold id; lra. - - Case "VectorSum"%string. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H closed). - symmetry in H. - generalize (vectoro_to_ovector_forall_some_f H); intros HH; clear H. - replace (@lift (@df_env floatish_R) R - (fun e : df_env => subvar (s, DTfloat) e d0) - (df_eval_backprop_deriv σ v grad_env - (ConstVector n 1%R))) - with - (df_eval_backprop_delta σ v (s, DTfloat) grad_env - (ConstVector n 1%R)). - rewrite (df_eval_backprop_delta_by_unit_parts _ _ _ _ d); trivial. - * intros. - specialize (HH i). - unfold df_eval_backprop_delta. - rewrite eqq0. - replace (UnitVector n i) with - (coerce - (df_eval_backward_gen_top_obligation_2 UnitAnn (DTVector n) v n eq_refl i) - (UnitVector n i)); trivial. - destruct i. - now simpl. - * unfold df_eval_backprop_delta. - now rewrite eqq0. - + specialize (H closed). - symmetry in H. - specialize (vectoro_to_ovector_exists_None H); intros. - destruct H0. - unfold lift in e. - match_option_in e. - generalize (backprop_deriv_fully_closed_not_none - σ v grad_env - (coerce - (df_eval_backward_gen_top_obligation_2 - UnitAnn (DTVector n) v n eq_refl x) - (UnitVector n x))); intros. - specialize (H0 closed); tauto. - - Case "MatrixSum"%string. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H closed). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros HH; simpl in HH. - replace (@lift (@df_env floatish_R) R - (fun e : df_env => subvar (s, DTfloat) e d0) - (df_eval_backprop_deriv σ v grad_env - (ConstMatrix m n 1%R))) - with - (df_eval_backprop_delta σ v (s, DTfloat) grad_env - (ConstMatrix m n 1%R)). - rewrite (df_eval_backprop_delta_by_unit_parts_mat _ _ _ _ d); trivial. - * intros. - specialize (HH i). - specialize (apply vectoro_to_ovector_forall_some_f HH); intros. - specialize (H0 j); simpl in H0. - unfold df_eval_backprop_delta. - rewrite eqq0. - replace (UnitMatrix m n i j) with - (coerce - (df_eval_backward_gen_top_obligation_3 - UnitAnn (DTMatrix m n) v m n eq_refl i j) - (UnitMatrix m n i j)); trivial. - destruct i; destruct j. - now simpl. - * unfold df_eval_backprop_delta. - now rewrite eqq0. - + specialize (H closed). - symmetry in H. - specialize (vectoro_to_ovector_exists_None H); intros. - destruct H0. - unfold lift in e. - specialize (vectoro_to_ovector_exists_None e); intros. - destruct H0. - match_option_in e0. - generalize (backprop_deriv_fully_closed_not_none - σ v grad_env - (coerce - (df_eval_backward_gen_top_obligation_3 - UnitAnn (DTMatrix m n) v m n eq_refl x x0) - (UnitMatrix m n x x0))); intros. - specialize (H0 closed); tauto. - - Case "VectorElem"%string. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H closed). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H. - unfold lift. - replace (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i) then 1%R else 0%R) - with - (UnitVector n i). - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env (UnitVector n i)); intros. - specialize (H1 closed). - specialize (H0 i). - match_option; symmetry; [|tauto]. - destruct i; simpl in *. - unfold lift in H0; f_equal. - rewrite eqq1 in H0. - now invcs H0. - unfold UnitVector. - apply FunctionalExtensionality.functional_extensionality; intros. - trivial. - + specialize (H closed). - symmetry in H. - specialize (vectoro_to_ovector_exists_None H); intros. - destruct H0. - unfold lift in e. - match_option_in e. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_2 - UnitAnn (DTVector n) l n eq_refl x) - (UnitVector n x))); intros. - specialize (H0 closed); tauto. - - Case "MatrixElem"%string. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H closed). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H. - unfold lift. - replace - (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) then - if equiv_dec (` k2) (` j) then 1%R else 0%R else 0%R) - with (UnitMatrix m n i j). - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env (UnitMatrix m n i j)); intros. - specialize (H1 closed). - specialize (H0 i). - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H2. - specialize (H2 j). - match_option; symmetry; [|tauto]. - destruct i; destruct j; simpl in *. - unfold lift in H2; f_equal. - rewrite eqq1 in H2. - now invcs H2. - unfold UnitMatrix. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - trivial. - + specialize (H closed). - symmetry in H. - specialize (vectoro_to_ovector_exists_None H); intros. - destruct H0. - unfold lift in e. - specialize (vectoro_to_ovector_exists_None e); intros. - destruct H0. - match_option_in e0. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_3 - UnitAnn (DTMatrix m n) l m n eq_refl x x0) - (UnitMatrix m n x x0))); intros. - specialize (H0 closed); tauto. - - Case "MatrixVectorMult"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros | tauto]. - assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option; [|tauto]. - assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - simpl_closed_backprop. - simpl_closed_backprop. - f_equal. - rewrite eqq1, eqq in H. - unfold lift in H; symmetry in H. - specialize (vectoro_to_ovector_forall_some_f H); intros. - specialize (H0 env). - simpler2. - cut_to H0; try congruence. - rewrite eqq0 in H0. - rewrite (split_subvar env env0 d3 val); trivial. - specialize (H9 i); simpl in H9. - specialize (vectoro_to_ovector_forall_some_f H9); intros. - replace (@vsum floatish_R n (fun j : {n' : nat | n' < n} => (d i j * d2 j + d1 i j * d0 j)%R)) - with ((vsum (fun j => (d i j * d2 j)%R)) + (vsum (fun j => (d1 i j * d0 j)%R)))%R - ; [|rewrite vsum_plus; f_equal]. - unfold lift in H0; symmetry in H0. - specialize (vectoro_to_ovector_forall_some_f H0); intros. - simpl in H11; simpl in H10. - destruct i; simpl in eqq2; simpl in eqq3. - unfold matrix_vector_mult in eqq3; simpl in eqq3. - unfold UnitVector in eqq2; simpl in eqq2. - replace (fun i : {n' : nat | n' < n} => - (@vsum floatish_R m - (fun j : {n' : nat | n' < m} => - (d j i * (@UnitVector floatish_R m (exist (fun n' : nat => (n' < m)%nat) x l0) - j)%R)%R))) - with (d (exist (fun n' : nat => (n' < m)%nat) x l0)) in eqq3. - + generalize (df_eval_backprop_delta_by_unit_partvec - σ r s env (d (exist (fun n' : nat => n' < m) x l0 )) - (fun i => ((d (exist (fun n' : nat => (n' < m)%nat) x l0 ) i) * (d2 i))%R)) - ; intros. - specialize (H12 H2). - cut_to H12; try congruence. - * unfold df_eval_backprop_delta in H12. - rewrite eqq3, eqq4 in H12. - unfold lift in H12. - invcs H12. - apply Rplus_eq_compat_l. - generalize (df_eval_backprop_delta_by_unit_partmat - σ l s grad_env - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - ((if equiv_dec (` i) x then 1 else 0) * d0 j)%R) - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - ((if equiv_dec (` i) x then 1 else 0) * (d1 i j) * (d0 j))%R)) - ; intros. - specialize (H12 H1). - cut_to H12; try congruence. - -- unfold df_eval_backprop_delta in H12. - rewrite eqq1 in H12. - unfold lift in H12. - rewrite eqq2 in H12. - invcs H12. - rewrite H15. - replace - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - ((if equiv_dec (` i) x then 1 else 0) * d1 i j * d0 j)%R) with - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (UnitVector m (exist _ x l0) i * d1 i j * d0 j)%R). - now rewrite msum_unitvector. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - now unfold UnitVector; simpl. - -- intros. - rewrite scalarMult_backprop_grad_scalar with (grad_env2:=grad_env) - ;trivial; try congruence. - unfold lift. - unfold df_eval_backprop_delta. - rewrite eqq1. - unfold lift. - specialize (vectoro_to_ovector_forall_some_f H); intros. - specialize (H13 i); simpl in H13. - specialize (vectoro_to_ovector_forall_some_f H13); intros. - specialize (H15 j); intros; simpl in H15. - destruct i; destruct j; simpl in H15. - match_option_in H15. - f_equal; simpl. - destruct (equiv_dec x0 x). - invcs H15. - rewrite H17; lra. - lra. - apply backprop_deriv_fully_closed_not_none; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - * intros. - replace (@scaleUnitVector (@float floatish_R) n i (d (@exist nat (fun n' : nat => lt n' m) x l0) i) (IZR Z0)) with - (scalarMult (DTVector n) (d (exist (fun n' : nat => n' < m) x l0) i) (UnitVector n i)). - rewrite scalarMult_backprop_grad_scalar with (grad_env1 := env) (grad_env2:=env);trivial; try congruence. - unfold scalarMult; simpl. - unfold lift; simpl. - specialize (H11 i). - destruct i; simpl in H11. - match_option_in H11. - unfold df_eval_backprop_delta. - rewrite eqq4,eqq5. - unfold lift; f_equal. - now invcs H11. - apply backprop_deriv_fully_closed_not_none; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - unfold scalarMult, scaleUnitVector. - apply FunctionalExtensionality.functional_extensionality; intros. - unfold UnitVector. - destruct (equiv_dec (` x0) (` i)); simpl; lra. - + apply FunctionalExtensionality.functional_extensionality; intros. - now rewrite vsum_unitvector. - - Case "MatrixVectorAdd"%string. - destruct closed. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3. - match_option; symmetry. - * apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H3 i). - specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H4 i0). - unfold lift in H4; unfold lift. - match_option_in H4. - unfold lift in H0. - rewrite eqq1 in H0. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat)) - ;intros. - specialize (H5 vin). - specialize (H0 d2 H5 H2). - match_option_in H0. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6. - specialize (H6 i). - destruct i; destruct i0; simpl; simpl in eqq2; simpl in H6. - rewrite eqq2. - match_option_in H6. - generalize (list_env_iter_total_fun - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ r env - (transpose - (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1)) i)) - d2 (bounded_seq0 n)); intros. - cut_to H7. - -- case_eq (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ r env - (@transpose R m n - (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1)) i)) - (Some d2) - (bounded_seq0 n)); [intros |tauto]. - f_equal. - rewrite (split_subvar d2 d5 d0 d3); trivial. - invcs H4. - invcs H6. - generalize (list_env_iter_matvec_delta - σ r s d2 - (exist (fun n' : nat => n' < n) x0 l1) - (exist (fun n' : nat => n' < m) x l0) d3). - intros. - specialize (H4 eqq3 H2). - rewrite eqq4 in H4. - replace (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ r env - (UnitMatrix n m (exist (fun n' : nat => n' < n) x0 l1) - (exist (fun n' : nat => n' < m) x l0) i)) with - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ r env - (@transpose R m n - (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1)) i)) in H4. - ++ rewrite H8 in H4. - unfold lift in H4. - invcs H4; lra. - ++ apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - f_equal. - unfold transpose, UnitMatrix. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl. - match_case; intros. - match_case; intros. - -- intros. - apply backprop_deriv_fully_closed_not_none; trivial. - * assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - - Case "MatrixMult"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros | tauto]. - assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option; [|tauto]. - assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - apply vectoro_to_ovector_forall_some_b_strong; intros. - simpl_closed_backprop. - simpl_closed_backprop. - f_equal. - rewrite eqq1, eqq in H. - unfold lift in H; symmetry in H. - specialize (vectoro_to_ovector_forall_some_f H); intros. - specialize (H0 env). - simpler2. - cut_to H0; try congruence. - rewrite eqq0 in H0. - rewrite (split_subvar env env0 d3 val); trivial. - specialize (H9 i); simpl in H9. - specialize (vectoro_to_ovector_forall_some_f H9); intros. - replace (@vsum floatish_R p (fun j => (d i j * d2 j i0 + d1 i j * d0 j i0)%R)) - with ((vsum (fun j => (d i j * d2 j i0)%R)) + (vsum (fun j => (d1 i j * d0 j i0)%R)))%R - ; [|rewrite vsum_plus; f_equal]. - unfold lift in H0; symmetry in H0. - specialize (vectoro_to_ovector_forall_some_f H0); intros. - simpl in H10; simpl in H11. - destruct i; destruct i0; simpl in eqq2; simpl in eqq3. - unfold matrix_mult,UnitMatrix in eqq3; simpl in eqq3. - unfold matrix_mult,UnitMatrix in eqq2; simpl in eqq2. - generalize (df_eval_backprop_delta_by_unit_partmat - σ l s grad_env - (fun (i : {n' : nat | n' < m}) (k : {m' : nat | m' < p}) => - vsum - (fun j : {n' : nat | n' < n} => - ((if equiv_dec (` i) x then - if equiv_dec (` j) x0 then 1 else 0 else 0) * d0 k j)%R)) - (fun (i : {n' : nat | n' < m}) (k : {m' : nat | m' < p}) => - vsum - (fun j : {n' : nat | n' < n} => - ((if equiv_dec (` i) x then - if equiv_dec (` j) x0 then 1 else 0 else 0) * - (d1 i k) * d0 k j)%R))) - ; intros. - specialize (H12 H1 vin). - cut_to H12; try congruence. - + unfold df_eval_backprop_delta in H12. - rewrite eqq1 in H12. - unfold lift in H12. - rewrite eqq2 in H12. - invcs H12. - rewrite H14. - assert (msum - (fun (i : {n' : nat | (n' < m)%nat}) (k : {m' : nat | (m' < p)%nat}) => - vsum - (fun j : {n' : nat | (n' < n)%nat} => - (if equiv_dec (` i) x then - if equiv_dec (` j) x0 then 1 else 0 else 0) - * d1 i k * d0 k j))%R = - vsum - (fun j : {n' : nat | (n' < p)%nat} => - d1 (exist (fun n' : nat => (n' < m)%nat) x l0) j * - d0 j (exist (fun n' : nat => (n' < n)%nat) x0 l1))%R). - * rewrite msum_transpose. - unfold msum. - f_equal. - unfold transpose; simpl. - replace (fun (i : {m' : nat | m' < p}) (j : {n' : nat | n' < m}) => - (@vsum floatish_R n - (fun j0 : {n' : nat | n' < n} => - ((if equiv_dec (` j) x then - if equiv_dec (` j0) x0 then 1 else 0 - else 0) * d1 j i * d0 i j0)%R))) with - (fun (i : {m' : nat | m' < p}) (j : {n' : nat | n' < m}) => - if equiv_dec (` j) x then - (d1 (exist _ x l0) i * d0 i (exist _ x0 l1))%R - else 0%R). - -- apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - replace (fun j : {n' : nat | n' < m} => - if equiv_dec (` j) x - then - (d1 (exist (fun n' : nat => (n' < m)%nat) x l0) x1 * - d0 x1 (exist (fun m' : nat => (m' < n)%nat) x0 l1))%R - else 0%R) with - (fun j : {n' : nat | n' < m} => - ((d1 (exist (fun n' : nat => (n' < m)%nat) x l0) x1 * - d0 x1 (exist (fun m' : nat => (m' < n)%nat) x0 l1))%R * - (@UnitVector floatish_R m (exist _ x l0) j))%R). - now rewrite vsum_unitvector. - apply FunctionalExtensionality.functional_extensionality; intros. - unfold UnitVector; simpl. - destruct (equiv_dec (` x2) x); lra. - -- apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct ( equiv_dec (` x2) x). - ++ replace - (fun j0 : {n' : nat | n' < n} => - ((if equiv_dec (` j0) x0 then 1 else 0) * d1 x2 x1 * d0 x1 j0)%R) - with - (fun j0 : {n' : nat | n' < n} => - ((d1 x2 x1 * d0 x1 j0) * (UnitVector n (exist _ x0 l1) j0))%R). - ** rewrite vsum_unitvector. - red in e. - subst. - destruct x2. - simpl. - erewrite index_pf_irrel; eauto. - ** apply FunctionalExtensionality.functional_extensionality; intros. - unfold UnitVector; simpl. - destruct (equiv_dec (` x3) x0); lra. - ++ rewrite <- vsum_mult. - lra. - * rewrite H12. - apply Rplus_eq_compat_r. - generalize (df_eval_backprop_delta_by_unit_partmat - σ r s env - (fun (i : {n' : nat | n' < p}) (k : {m' : nat | m' < n}) => - vsum - (fun j : {n' : nat | n' < m} => - (d j i * - (if equiv_dec (` j) x then - if equiv_dec (` k) x0 then 1 else 0 else 0))%R)) - (fun (i : {n' : nat | n' < p}) (k : {m' : nat | m' < n}) => - vsum - (fun j : {n' : nat | n' < m} => - (d j i * d2 i k * - (if equiv_dec (` j) x then - if equiv_dec (` k) x0 then 1 else 0 else 0))%R))) - ; intros. - specialize (H13 H2). - cut_to H13; try congruence. - -- unfold df_eval_backprop_delta in H13. - rewrite eqq4 in H13. - unfold lift in H13. - rewrite eqq3 in H13. - invcs H13. - rewrite H16. - unfold msum; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - replace - (fun k : {m' : nat | m' < n} => - (@vsum floatish_R m - (fun j : {n' : nat | n' < m} => - (d j x1 * d2 x1 k * - (if equiv_dec (` j) x then - if equiv_dec (` k) x0 then 1 else 0 else 0))%R))) with - (fun k : {m' : nat | m' < n} => - (vsum - (fun j => - (d j x1 * d2 x1 k * - (UnitVector m (exist _ x l0) j))%R) - * (UnitVector n (exist _ x0 l1) k))%R). - ++ rewrite vsum_unitvector. - rewrite vsum_unitvector. - lra. - ++ apply FunctionalExtensionality.functional_extensionality; intros. - rewrite Rmult_comm. - rewrite vsum_mult. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - unfold UnitVector; simpl. - destruct (equiv_dec (` x2) x0); destruct (equiv_dec (` x3) x); lra. - -- intros. - rewrite scalarMult_backprop_grad_scalar with (grad_env2:=env); trivial; try congruence. - ++ unfold lift. - unfold df_eval_backprop_delta. - rewrite eqq4. - unfold lift. - specialize (H11 i); simpl in H11. - specialize (vectoro_to_ovector_forall_some_f H11); intros. - specialize (H15 j); intros; simpl in H15. - destruct i; destruct j; simpl in H15. - match_option_in H15. - f_equal; simpl. - invcs H15. - rewrite H17. - rewrite Rmult_comm. - rewrite vsum_mult. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - ++ apply backprop_deriv_fully_closed_not_none; trivial. - ++ apply backprop_deriv_fully_closed_not_none; trivial. - + intros. - rewrite scalarMult_backprop_grad_scalar with (grad_env2:=grad_env); trivial; try congruence. - * unfold lift. - unfold df_eval_backprop_delta. - rewrite eqq1. - unfold lift. - specialize (vectoro_to_ovector_forall_some_f H i); intros; simpl in H11. - specialize (vectoro_to_ovector_forall_some_f H13 j); intros; simpl in H14. - destruct i; destruct j; simpl in H14. - match_option_in H14. - f_equal; simpl. - invcs H14. - rewrite H16. - rewrite Rmult_comm. - rewrite vsum_mult. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - * apply backprop_deriv_fully_closed_not_none; trivial. - * apply backprop_deriv_fully_closed_not_none; trivial. - - Case "VectorPlus"%string. - destruct closed. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3. - match_option; symmetry. - * apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H3 i). - unfold lift in H3; unfold lift. - match_option_in H3. - unfold lift in H0. - rewrite eqq1 in H0. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat)) - ;intros. - specialize (H4 vin). - specialize (H0 d2 H4 H2). - match_option_in H0. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H5. - specialize (H5 i). - destruct i; simpl in H1; simpl; simpl in eqq2; simpl in H5. - rewrite eqq2. - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (UnitVector n (exist (fun n' : nat => n' < n) x l0))) - ; intros; specialize (H6 H2). - match_option; [|tauto]; f_equal. - rewrite eqq4 in H5; inversion H5. - rewrite (split_subvar d2 d4 d0 d3); trivial. - inversion H3; lra. - * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros. - specialize (H4 H2). - tauto. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_exists_None H); intros. - destruct H3. - unfold lift in e. - match_option_in e. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_2 - UnitAnn (DTVector n) l n eq_refl x) - (UnitVector n x))); intros. - specialize (H3 H1); tauto. - - Case "VectorMinus"%string. - destruct closed. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3. - match_option; symmetry. - * apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H3 i). - unfold lift in H3; unfold lift. - match_option_in H3. - unfold lift in H0. - rewrite eqq1 in H0. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat)) - ;intros. - specialize (H4 vin). - specialize (H0 d2 H4 H2). - match_option_in H0. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H5. - specialize (H5 i). - destruct i; simpl in H0; simpl; simpl in eqq2; simpl in H5. - rewrite eqq2. - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (UnitVector n (exist (fun n' : nat => n' < n) x l0))) - ; intros; specialize (H6 H2). - generalize (scalarMult_backprop_grad_scalar - σ r s d2 d2 - (UnitVector n (exist (fun n' : nat => n' < n) x l0)) - (-1)%R ); intros. - specialize (H7 H4 H4). - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (scalarMult (DTVector n) (-1)%R - (UnitVector n (exist (fun n' : nat => n' < n) x l0)))) - ; intros; specialize (H8 H2). - specialize (H7 H8 H6). - unfold df_eval_backprop_delta in H7; simpl in H7. - rewrite eqq3 in H7. - unfold lift in H7; simpl in H7. - replace (fun i : {n' : nat | n' < n} => - (- (@UnitVector floatish_R n (exist (fun n' : nat => (n' < n)%nat) x l0)) i)%R) - with - (fun i : {n' : nat | n' < n} => - (-1 * UnitVector n (exist (fun n' : nat => (n' < n)%nat) x l0) i)%R). - match_option; [|tauto]; f_equal. - rewrite eqq4 in H7. - rewrite (split_subvar d2 d4 d0 d3); trivial. - rewrite H5 in H7. - inversion H7. - rewrite H10. - inversion H3; lra. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros. - specialize (H4 H2). - tauto. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_exists_None H); intros. - destruct H3. - unfold lift in e. - match_option_in e. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_2 UnitAnn (DTVector n) l n eq_refl x) - (UnitVector n x))); intros. - specialize (H3 H1); tauto. - - Case "MatrixPlus"%string. - destruct closed. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H1. - match_option; symmetry. - * apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H3 i). - specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H4 i0). - unfold lift in H4; unfold lift. - match_option_in H4. - unfold lift in H0. - rewrite eqq1 in H0. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat)) - ;intros. - specialize (H5 vin). - specialize (H0 d2 H5 H2). - match_option_in H0. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6. - specialize (H6 i). - specialize (apply vectoro_to_ovector_forall_some_f H6); intros; simpl in H7. - specialize (H7 i0). - destruct i; destruct i0; simpl; simpl in eqq2; simpl in H7. - rewrite eqq2. - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0) - (exist (fun n' : nat => n' < m) x0 l1))) - ; intros; specialize (H8 H2). - match_option; [|tauto]; f_equal. - rewrite eqq4 in H7; inversion H7. - rewrite (split_subvar d2 d4 d0 d3); trivial. - inversion H4; lra. - * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros. - now specialize (H4 H2). - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_exists_None H); intros. - destruct H3. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H3. - unfold lift in e0. - match_option_in e0. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_3 UnitAnn (DTMatrix n m) l n m eq_refl x - x0) (UnitMatrix n m x x0))); intros. - specialize (H3 H1); tauto. - - Case "MatrixMinus"%string. - destruct closed. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H1. - match_option; symmetry. - * apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H3 i). - specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H4 i0). - unfold lift in H4; unfold lift. - match_option_in H4. - unfold lift in H0. - rewrite eqq1 in H0. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat)) - ;intros. - specialize (H5 vin). - specialize (H0 d2 H5 H2). - match_option_in H0. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6. - specialize (H6 i). - specialize (apply vectoro_to_ovector_forall_some_f H6); intros; simpl in H7. - specialize (H7 i0). - destruct i; destruct i0; simpl; simpl in eqq2; simpl in H7. - rewrite eqq2. - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0) - (exist (fun n' : nat => n' < m) x0 l1))) - ; intros; specialize (H8 H2). - - generalize (scalarMult_backprop_grad_scalar - σ r s d2 d2 - (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0) - (exist (fun n' : nat => n' < m) x0 l1)) - (-1)%R ); intros. - specialize (H9 H5 H5). - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (scalarMult (DTMatrix n m) (-1)%R - (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0) - (exist (fun n' : nat => n' < m) x0 l1)))) - ; intros; specialize (H10 H2). - specialize (H9 H10 H8). - unfold df_eval_backprop_delta in H9; simpl in H9. - rewrite eqq3 in H9. - unfold lift in H9; simpl in H9. - replace - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (- - (@UnitMatrix floatish_R n m (exist (fun n' : nat => (n' < n)%nat) x l0) - (exist (fun n' : nat => (n' < m)%nat) x0 l1)) i j)%R) - with - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (-1 * - UnitMatrix n m (exist (fun n' : nat => (n' < n)%nat) x l0) - (exist (fun n' : nat => (n' < m)%nat) x0 l1) i j)%R). - match_option; [|tauto]; f_equal. - rewrite eqq4 in H9. - rewrite (split_subvar d2 d4 d0 d3); trivial. - rewrite H7 in H9. - inversion H9. - rewrite H12. - inversion H4; lra. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros. - now specialize (H4 H2). - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_exists_None H); intros. - destruct H3. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H3. - unfold lift in e0. - match_option_in e0. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_3 UnitAnn (DTMatrix n m) l n m eq_refl x - x0) (UnitMatrix n m x x0))); intros. - specialize (H3 H1); tauto. - - Case "VectorScalMult"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ x); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H2). - case_eq (df_eval σ l); [intros|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - case_eq (df_eval_backprop_deriv σ x grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - match_option; unfold lift; symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - replace (@vsum - floatish_R n - (fun j : @sig nat (fun n' : nat => lt n' n) => - Rmult (d1 j) - (@coerce (Vector R n) (Vector R n) - (@df_eval_backward_gen_top_obligation_2 - floatish_R UnitAnn - (DTVector n) - (@VectorScalMult floatish_R UnitAnn n ann x l) n - (@eq_refl definition_function_types (DTVector n)) i) - (@UnitVector floatish_R n i) j))) - with (d1 i). - generalize (scalarMult_backprop_grad_scalar - σ x s grad_env grad_env 1%R (d1 i)); intros; simpl in H7. - unfold df_eval_backprop_delta in H7. - rewrite eqq1 in H7; simpl in H7. - specialize (H7 vin vin). - rewrite eqq0 in H7. - generalize (backprop_deriv_fully_closed_not_none - σ x grad_env (d1 i * 1)%R ); intros. - specialize (H8 H1); specialize (H7 H8). - cut_to H7; try discriminate. - invcs H. - { specialize (H0 d3). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - | apply H2]. - unfold lift in H7; simpl in H7. - match_option_in H7. - replace (d1 i) with (d1 i * 1)%R by lra. - rewrite eqq3. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros. - specialize (H vin). - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H9 in H0. - unfold lift; simpl. - unfold lift in H0. - rewrite eqq2 in H0. - generalize (scalarMult_backprop_grad_scalar σ l s d2 d3 (UnitVector n i) d) - ; intros; simpl in H10. - unfold df_eval_backprop_delta in H10; simpl in H10. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat));intros. - specialize (H11 vin). - case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence]. - specialize (H10 H11 H). - rewrite H9, H12 in H10. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros. - specialize (H13 i); simpl in H13. - generalize (backprop_deriv_fully_closed_not_none - σ l d2 - (fun i0 : {n' : nat | n' < n} => (d * UnitVector n i i0)%R)); intros. - specialize (H14 H2); specialize (H10 H14). - generalize (backprop_deriv_fully_closed_not_none - σ l d3 (UnitVector n i)); intros. - specialize (H15 H2); specialize (H10 H15). - unfold lift in H10. - match_option_in H10; [|congruence]; f_equal. - replace - (fun j : @sig nat (fun n' : nat => lt n' n) => - Rmult d - (@coerce (Vector R n) (Vector R n) - (@df_eval_backward_gen_top_obligation_2 - floatish_R UnitAnn - (DTVector n) (@VectorScalMult floatish_R UnitAnn n ann x l) n - (@eq_refl definition_function_types (DTVector n)) i) - (@UnitVector floatish_R n i) j)) - with - (fun i0 : {n' : nat | n' < n} => (d * UnitVector n i i0)%R). - rewrite eqq4. - rewrite (split_subvar d2 d7 d0 d6); trivial; f_equal. - match_option_in H10; inversion H10. - match_option_in H13. - invcs H7; invcs H13. - rewrite H17. - destruct i; simpl in eqq6. - rewrite eqq6 in eqq5. - invcs eqq5. - lra. - apply FunctionalExtensionality.functional_extensionality; intros. - now destruct i; simpl. - } - + symmetry. - destruct i; simpl. - apply vsum_unitvector. - + specialize (H0 d3). - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros. - specialize (H7 vin). - specialize (H0 H7 H2). - rewrite eqq2 in H0. - unfold lift in H0. - symmetry in H0. - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H8 in H0; simpl in H0. - specialize (apply vectoro_to_ovector_exists_None H0); intros. - destruct H9. - generalize (backprop_deriv_fully_closed_not_none - σ l d3 - (coerce - (df_eval_backward_gen_top_obligation_2 - UnitAnn (DTVector n) l n eq_refl x0) - (UnitVector n x0))); intros. - specialize (H9 H2). - match_option_in e. - tauto. - + unfold lift in H. - generalize (backprop_deriv_fully_closed_not_none - σ x grad_env 1%R); intros. - specialize (H7 H1). - match_option_in H. - tauto. - - Case "MatrixScalMult"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ x); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H2). - case_eq (df_eval σ l); [intros|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - case_eq (df_eval_backprop_deriv σ x grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - match_option; unfold lift; symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - apply vectoro_to_ovector_forall_some_b_strong; intros. - replace - (@msum floatish_R n m - (fun (i1 : @sig nat (fun n' : nat => lt n' n)) - (j : @sig nat (fun m' : nat => lt m' m)) => - Rmult (d1 i1 j) - (@coerce (Matrix R n m) (Matrix R n m) - (@df_eval_backward_gen_top_obligation_3 floatish_R UnitAnn - (DTMatrix n m) (@MatrixScalMult floatish_R UnitAnn n m ann x l) n m - (@eq_refl definition_function_types (DTMatrix n m)) i i0) - (@UnitMatrix floatish_R n m i i0) i1 j))) - with (d1 i i0). - generalize (scalarMult_backprop_grad_scalar - σ x s grad_env grad_env 1%R (d1 i i0)); intros; simpl in H7. - unfold df_eval_backprop_delta in H7. - - rewrite eqq1 in H7; simpl in H7. - specialize (H7 vin vin). - rewrite eqq0 in H7. - generalize (backprop_deriv_fully_closed_not_none - σ x grad_env (d1 i i0 * 1)%R ); intros. - specialize (H8 H1); specialize (H7 H8). - cut_to H7; try discriminate. - invcs H. - { specialize (H0 d3). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - | apply H2]. - unfold lift in H7; simpl in H7. - match_option_in H7. - replace (d1 i i0) with (d1 i i0 * 1)%R by lra. - rewrite eqq3. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros. - specialize (H vin). - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H9 in H0. - unfold lift in H0. - rewrite eqq2 in H0. - generalize (scalarMult_backprop_grad_scalar σ l s d2 d3 (UnitMatrix n m i i0) d) - ; intros; simpl in H10. - unfold df_eval_backprop_delta in H10; simpl in H10. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat));intros. - specialize (H11 vin). - case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence]. - specialize (H10 H11 H). - rewrite H9, H12 in H10. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros. - specialize (H13 i); simpl in H13. - specialize (apply vectoro_to_ovector_forall_some_f H13); intros. - specialize (H14 i0); simpl in H14. - generalize (backprop_deriv_fully_closed_not_none - σ l d2 - (fun (i1 : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d * UnitMatrix n m i i0 i1 j)%R)); intros. - specialize (H15 H2); specialize (H10 H15). - generalize (backprop_deriv_fully_closed_not_none - σ l d3 (UnitMatrix n m i i0)); intros. - specialize (H16 H2); specialize (H10 H16). - unfold lift in H10. - match_option_in H10; [|congruence]; f_equal. - replace - (fun (i1 : @sig nat (fun n' : nat => lt n' n)) - (j : @sig nat (fun m' : nat => lt m' m)) => - Rmult - (@coerce (Matrix R n m) (Matrix R n m) - (@df_eval_backward_gen_top_obligation_3 - floatish_R UnitAnn - (DTMatrix n m) (@MatrixScalMult floatish_R UnitAnn n m ann x l) n m - (@eq_refl definition_function_types (DTMatrix n m)) i i0) - (@UnitMatrix floatish_R n m i i0) i1 j) d) - with - (fun i1 j => (d * UnitMatrix n m i i0 i1 j)%R). - rewrite eqq4. - rewrite (split_subvar d2 d7 d0 d6); trivial; f_equal. - match_option_in H10; inversion H10. - match_option_in H14. - rewrite H18. - invcs H7; invcs H14. - destruct i; destruct i0; simpl in eqq6. - rewrite eqq6 in eqq5. - invcs eqq5. - rewrite H17. - lra. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct i; destruct i0; simpl. - lra. - } - + symmetry. - destruct i; destruct i0; simpl. - apply msum_unitmatrix. - + specialize (H0 d3). - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros. - specialize (H7 vin). - specialize (H0 H7 H2). - rewrite eqq2 in H0. - unfold lift in H0. - symmetry in H0. - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H8 in H0; simpl in H0. - specialize (apply vectoro_to_ovector_exists_None H0); intros. - destruct H9. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H9. - generalize (backprop_deriv_fully_closed_not_none - σ l d3 - (coerce - (df_eval_backward_gen_top_obligation_3 - UnitAnn (DTMatrix n m) l n m eq_refl x0 x1) - (UnitMatrix n m x0 x1))); intros. - specialize (H9 H2). - match_option_in e0. - tauto. - + unfold lift in H. - generalize (backprop_deriv_fully_closed_not_none - σ x grad_env 1%R); intros. - specialize (H7 H1). - match_option_in H. - tauto. - - Case "VectorApply"%string. - destruct closed. - generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (0%R) :: nil) s0); intros. - specialize (H4 H2). - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H3). - match_case; [intros|tauto]. - specialize (H1 grad_env vin H3). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto]. - rewrite eqq1 in H1; simpl in H1. - match_option - ; rewrite eqq in H1; unfold lift in H1; symmetry in H1. - + specialize (apply vectoro_to_ovector_forall_some_f H1); intros; simpl in H7. - unfold lift; simpl. - match_option. - * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H8 i). - specialize (H7 i). - destruct i. - simpl. - match_nested_case. - -- assert ( df_eval_backprop_deriv σ l grad_env v1 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - f_equal. - simpl in H7. - match_option_in H8. - specialize (vectoro_to_ovector_forall_some_f eqq2); intros. - assert (H10c := H10). - specialize (H10 (exist _ x l0)). - rewrite vmap_nth in H10; simpl in H10. - match_option_in H10; simpl in H10. - unfold UnitVector in H10; simpl in H10. - destruct (equiv_dec x x); [|congruence]. - invcs H8; rewrite H12. - assert (v1 = - scalarMult (DTVector n) d4 - (UnitVector n (exist (fun n' : nat => n' < n) x l0))). - ++ unfold scalarMult; simpl. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H10c x0); simpl in H10c. - rewrite vmap_nth in H10c; simpl in H10c. - match_option_in H10c. - invcs H10c. - destruct x0. - unfold UnitVector; simpl. - destruct (equiv_dec x0 x). - ** red in e0. - subst. - erewrite index_pf_irrel in eqq6. - rewrite eqq5 in eqq6. - invcs eqq6; lra. - ** lra. - ++ replace (1 * d4)%R with d4 in H10 by lra. - generalize (scalarMult_backprop_grad_scalar - σ l s grad_env grad_env - (UnitVector n (exist (fun n' : nat => n' < n) x l0)) d4) - ; intros. - simpl in H11; cut_to H11; trivial; try congruence. - ** unfold df_eval_backprop_delta in H11. - rewrite eqq1 in H11. - unfold lift in H11; simpl in H11. - rewrite H8 in eqq3. - unfold scalarMult in eqq3; simpl in eqq3. - match_option_in H7. - rewrite eqq3, eqq6 in H11. - invcs H11; rewrite H14. - invcs H7; rewrite H11. - generalize - (df_eval_deriv_genvar_same - [mk_env_entry (v, DTfloat) - (d (exist (fun n' : nat => n' < n) x l0) )] - s0 v); simpl; intros. - specialize (H7 H H2). - unfold lift, df_eval_deriv_gen_top in H7; simpl in H7. - rewrite <- H12. - unfold mk_genvar_env in eqq4; simpl in eqq4. - rewrite eqq4, eqq5 in H7. - invcs H7; lra. - ** apply backprop_deriv_fully_closed_not_none; trivial. - ** apply backprop_deriv_fully_closed_not_none; trivial. - -- apply vectoro_to_ovector_exists_None in eqq2; destruct eqq2. - rewrite vmap_nth in e; simpl in e. - assert (df_eval [mk_env_entry (v, DTfloat) (d x0)] - (df_deriv s0 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - match_option_in e; tauto. - * apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0. - assert (df_eval_deriv_genvar [mk_env_entry (v, DTfloat) (d x)] s0 - [mk_env_entry (v, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option_in e; tauto. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - - Case "MatrixApply"%string. - destruct closed. - generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (0%R) :: nil) s0); intros. - specialize (H4 H2). - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H3). - match_case; [intros|tauto]. - specialize (H1 grad_env vin H3). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto]. - rewrite eqq1 in H1; simpl in H1. - match_option - ; rewrite eqq in H1; unfold lift in H1; symmetry in H1. - + specialize (apply vectoro_to_ovector_forall_some_f H1); intros; simpl in H7. - unfold lift; simpl. - match_option. - * specialize (apply vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H8 i). - specialize (H7 i). - match_nested_case. - -- assert ( df_eval_backprop_deriv σ l grad_env m1 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - f_equal. - simpl in H7. - specialize (vectoro_to_ovector_forall_some_f H8); intros. - specialize (vectoro_to_ovector_forall_some_f H7); intros. - specialize (H10 i0); simpl in H10. - match_option_in H10. - specialize (vectoro_to_ovector_forall_some_f eqq2); intros. - assert (H12c := H12). - specialize (H12 i); simpl in H12. - specialize (vectoro_to_ovector_forall_some_f H12); intros. - specialize (H13 i0); simpl in H13; unfold mmap in H13. - rewrite vmap_nth in H13; simpl in H13. - rewrite vmap_nth in H13; simpl in H13. - destruct i; destruct i0; simpl in H13. - simpl in *. - unfold matrix_zip in H13. - rewrite vmap_nth in H13; simpl in H13. - match_option_in H13; simpl in H13. - unfold UnitMatrix in H13; simpl in H13. - destruct (equiv_dec x x); [|congruence]. - destruct (equiv_dec x0 x0); [|congruence]. - invcs H13; invcs H10. - assert (m1 = - scalarMult (DTMatrix m n) d4 - (UnitMatrix m n - (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1))). - ++ unfold scalarMult; simpl. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H12c x1); simpl in H12c. - specialize (vectoro_to_ovector_forall_some_f H12c); intros. - specialize (H10 x2); simpl in H10; unfold mmap in H10. - - rewrite vmap_nth in H10; simpl in H10. - rewrite vmap_nth in H10; simpl in H10. - unfold matrix_zip in H10. - rewrite vmap_nth in H10; simpl in H10. - match_option_in H10. - invcs H10. - destruct x1; destruct x2. - unfold UnitMatrix; simpl. - destruct (equiv_dec x1 x). - ** red in e1; subst. - destruct (equiv_dec x2 x0). - --- red in e1; subst. - rewrite index_pf_irrel with (pf2 := l0) in eqq6. - rewrite index_pf_irrel with (pf1 := l3) (pf2 := l1) in eqq6. - rewrite eqq5 in eqq6. - invcs eqq6; lra. - --- lra. - ** lra. - ++ replace (1 * d4)%R with d4 in H15 by lra. - generalize (scalarMult_backprop_grad_scalar - σ l s grad_env grad_env - (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1)) d4) - ; intros. - simpl in H13; cut_to H13; trivial; try congruence. - ** unfold df_eval_backprop_delta in H13. - rewrite eqq1 in H13. - unfold lift in H13; simpl in H13. - rewrite H10 in eqq3. - unfold scalarMult in eqq3; simpl in eqq3. - specialize (H11 (exist (fun n' : nat => n' < n) x0 l1)). - simpl in H11. - match_option_in H11. - rewrite eqq3, eqq6 in H13. - - generalize - (df_eval_deriv_genvar_same - [mk_env_entry (v, DTfloat) - (d (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1))] - s0 v). - simpl; intros. - specialize (H16 H H2). - unfold lift, df_eval_deriv_gen_top in H16; simpl in H16. - unfold mk_genvar_env in eqq4; simpl in eqq4. - rewrite eqq4, eqq5 in H16. - invcs H13; rewrite H18. - invcs H16. - invcs H11; rewrite H15; lra. - ** apply backprop_deriv_fully_closed_not_none; trivial. - ** apply backprop_deriv_fully_closed_not_none; trivial. - -- unfold matrixo_to_omatrix in eqq2. - apply vectoro_to_ovector_exists_None in eqq2; destruct eqq2. - apply vectoro_to_ovector_exists_None in e; destruct e. - unfold mmap in e. - rewrite vmap_nth in e; simpl in e. - rewrite vmap_nth in e; simpl in e. - rewrite matrix_zip_m_n in e. - destruct i; destruct i0. - simpl in e. - assert (df_eval [mk_env_entry (v, DTfloat) - (d x x0)] - (df_deriv s0 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - match_option_in e; tauto. - * unfold matrixo_to_omatrix in eqq0. - apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0. - apply vectoro_to_ovector_exists_None in e; destruct e. - assert (df_eval_deriv_genvar [mk_env_entry (v, DTfloat) (d x x0)] s0 - [mk_env_entry (v, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option_in e; tauto. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - - Case "VLossfun"%string. - destruct closed. - generalize (eval_fully_closed_not_none - (mk_env_entry (v1, DTfloat) (0%R) :: - (mk_env_entry (v2, DTfloat) (0%R) :: nil)) s0); intros. - specialize (H4 H2). - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H3). - match_case; [intros|tauto]. - specialize (H1 grad_env vin H3). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto]. - rewrite eqq1 in H1; simpl in H1. - match_option - ; rewrite eqq in H1; unfold lift in H1; symmetry in H1. - + specialize (vectoro_to_ovector_forall_some_f H1); intros; simpl in H7. - unfold lift; simpl. - match_nested_case. - * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8. - symmetry. - match_nested_case. - -- specialize (vectoro_to_ovector_forall_some_f eqq2); intros; simpl in H9. - assert ( df_eval_backprop_deriv σ l grad_env v0 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - unfold snd in d1; simpl in d1. - assert (forall i : {n' : nat | n' < n}, - df_eval [mk_env_entry (v1, DTfloat) (d i); mk_env_entry (v2, DTfloat) (r i)] - (df_deriv s0 (v1, DTfloat)) - = Some (v0 i)). - intros ; specialize (H9 i) - ; rewrite vmap_nth in H9 - ; simpl in H9 - ; match_option_in H9 - ;inversion H9 ;f_equal; lra. - generalize (df_eval_backprop_delta_by_unit_partvec σ l s grad_env v0 v). - intros. - specialize (H12 H3 vin). - cut_to H12. - ++ unfold df_eval_backprop_delta, lift in H12. - now rewrite eqq1, eqq3 in H12. - ++ intros. - replace (@scaleUnitVector (@float floatish_R) n i (v0 i) (IZR Z0)) with - (scalarMult (DTVector n) (v0 i) (UnitVector n i)) by - (unfold scalarMult, scaleUnitVector, UnitVector; - apply FunctionalExtensionality.functional_extensionality; intros; - simpl; destruct (equiv_dec (` x) (` i)); lra). - rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env); trivial. - ** unfold df_eval_backprop_delta. - rewrite eqq1. - specialize (H7 i); simpl in H7. - specialize (H8 i); simpl in H8. - specialize (H9 i); simpl in H9. - generalize - (df_eval_deriv_genvar_same - [mk_env_entry (v1, DTfloat) (d i); mk_env_entry (v2, DTfloat) (r i)] - s0 v1); simpl; intros. - specialize (H13 H H2). - unfold df_eval_deriv_gen_top in H13; simpl in H13. - unfold lift in H13; simpl in H13. - match_option_in H13. - --- unfold mk_genvar_env in H8; simpl in H8. - rewrite eqq4 in H8. - specialize (H11 i). - destruct i; simpl in H7. - match_option_in H7. - unfold lift; f_equal. - unfold scalarMult; simpl. - invcs H8. - invcs H7. - rewrite H14. - rewrite <- H13 in H11. - invcs H11. - lra. - --- assert (df_eval_deriv_genvar - [mk_env_entry (v1, DTfloat) (d i); - mk_env_entry (v2, DTfloat) (r i)] s0 - [mk_env_entry (v1, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - ** apply backprop_deriv_fully_closed_not_none; trivial. - ** apply backprop_deriv_fully_closed_not_none; trivial. - -- apply vectoro_to_ovector_exists_None in eqq2. - destruct eqq2. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x); - mk_env_entry (v2, DTfloat) (r x)] - (df_deriv s0 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none; simpl. - apply fully_closed_deriv; trivial. - tauto. - * apply vectoro_to_ovector_exists_None in eqq0. - destruct eqq0. - match_option_in e. - assert (df_eval_deriv_genvar - [mk_env_entry (v1, DTfloat) (d x); - mk_env_entry (v2, DTfloat) (r x)] s0 - [mk_env_entry (v1, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - - Case "MLossfun"%string. - destruct closed. - generalize (eval_fully_closed_not_none - (mk_env_entry (v1, DTfloat) (0%R) :: - (mk_env_entry (v2, DTfloat) (0%R) :: nil)) s0); intros. - specialize (H4 H2). - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H3). - match_case; [intros|tauto]. - specialize (H1 grad_env vin H3). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto]. - rewrite eqq1 in H1; simpl in H1. - match_option - ; rewrite eqq in H1; unfold lift in H1; symmetry in H1. - + specialize (vectoro_to_ovector_forall_some_f H1); intros; simpl in H7. - unfold lift; simpl. - match_nested_case. - * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8. - symmetry. - match_nested_case. - -- specialize (vectoro_to_ovector_forall_some_f eqq2); intros; simpl in H9. - assert ( df_eval_backprop_deriv σ l grad_env m1 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - unfold snd in d1; simpl in d1. - assert (forall (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}), - match - df_eval [mk_env_entry (v1, DTfloat) (d i j); - mk_env_entry (v2, DTfloat) (r i j)] - (df_deriv s0 (v1, DTfloat)) - with - | Some se => Some (1 * se / IZR (Z.of_nat n))%R - | None => None - end = Some (m1 i j)); intros. - ++ specialize (H9 i). - specialize (vectoro_to_ovector_forall_some_f H9); intros. - specialize (H11 j); simpl in H11; unfold mmap in H11. - do 2 rewrite vmap_nth in H11; simpl in H11. - unfold matrix_zip, vector_zip in H11; simpl in H11. - rewrite vmap_nth in H11; simpl in H11. - match_option; rewrite eqq4 in H11; apply H11. - ++ generalize (df_eval_backprop_delta_by_unit_partmat - σ l s grad_env m1 - (mmap (fun u => u / IZR (Z.of_nat n))%R m0)); intros. - specialize (H12 H3 vin). - cut_to H12. - ** unfold df_eval_backprop_delta, lift in H12. - rewrite eqq1, eqq3 in H12. - replace ((@msum floatish_R m n m0) / IZR (Z.of_nat n))%R with - (msum (mmap (fun u : R => (u / IZR (Z.of_nat n))%R) m0)). - --- apply H12. - --- now rewrite msum_mmap_div_denom. - ** intros. - rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env) - ; trivial. - --- unfold df_eval_backprop_delta. - rewrite eqq1. - specialize (H7 i); simpl in H7. - specialize (H8 i); simpl in H8. - specialize (H11 i j); simpl in H11. - specialize (vectoro_to_ovector_forall_some_f H7); intros. - specialize (H13 j); simpl in H13. - specialize (vectoro_to_ovector_forall_some_f H8); intros. - specialize (H14 j); simpl in H14. - generalize - (df_eval_deriv_genvar_same - [mk_env_entry (v1, DTfloat) (d i j); - mk_env_entry (v2, DTfloat) (r i j)] - s0 v1); simpl; intros. - specialize (H15 H H2). - unfold df_eval_deriv_gen_top in H15; simpl in H15. - unfold lift in H15; simpl in H15. - match_option_in H15. - +++ unfold mk_genvar_env in H14; simpl in H14. - rewrite eqq4 in H14. - destruct i; destruct j; simpl in H13. - match_option_in H13. - unfold lift; f_equal. - invcs H13. - invcs H14. - rewrite H17. - unfold mmap. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite <- H16. - rewrite <- H15 in H11. - invcs H11. - lra. - +++ assert (df_eval_deriv_genvar - [mk_env_entry (v1, DTfloat) (d i j); - mk_env_entry (v2, DTfloat) (r i j)] s0 - [mk_env_entry (v1, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - --- apply backprop_deriv_fully_closed_not_none; trivial. - --- apply backprop_deriv_fully_closed_not_none; trivial. - -- unfold matrixo_to_omatrix in eqq2. - apply vectoro_to_ovector_exists_None in eqq2. - destruct eqq2. - apply vectoro_to_ovector_exists_None in e. - destruct e. - unfold mmap in e. - rewrite vmap_nth in e; simpl in e. - rewrite vmap_nth in e; simpl in e. - unfold matrix_zip,vector_zip in e; simpl in e. - rewrite vmap_nth in e. - match_option_in e. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] - (df_deriv s0 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none; simpl. - apply fully_closed_deriv; trivial. - tauto. - * unfold matrixo_to_omatrix in eqq0. - apply vectoro_to_ovector_exists_None in eqq0. - destruct eqq0. - apply vectoro_to_ovector_exists_None in e. - destruct e. - match_option_in e. - assert (df_eval_deriv_genvar - [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] s0 - [mk_env_entry (v1, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - Qed. - -(* -Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "DVector"%string - | Case_aux c "DMatrix"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string - | Case_aux c "VectorDot"%string - | Case_aux c "VectorSum"%string - | Case_aux c "MatrixSum"%string - | Case_aux c "VectorElem"%string - | Case_aux c "MatrixElem"%string - | Case_aux c "MatrixVectorMult"%string - | Case_aux c "MatrixVectorAdd"%string - | Case_aux c "MatrixMult"%string - | Case_aux c "VectorPlus"%string - | Case_aux c "VectorMinus"%string - | Case_aux c "MatrixPlus"%string - | Case_aux c "MatrixMinus"%string - | Case_aux c "VectorScalMult"%string - | Case_aux c "MatrixScalMult"%string - | Case_aux c "VectorApply"%string - | Case_aux c "MatrixApply"%string - | Case_aux c "VLossfun"%string - | Case_aux c "MLossfun"%string]. - - - Lemma tree_backpropeq_complete_gen {T} (env gradenv : df_env) - (dfexpr : DefinedFunction EvalAnn T) (grad : definition_function_types_interp T) : - forall (x : SubVar), - let xvar := (x, DTfloat) in - vartlookup gradenv (x,DTfloat) <> None -> - match df_eval_tree_deriv env dfexpr xvar, - backprop_lookup (Some gradenv) xvar, - backprop_lookup (df_eval_tree_backprop_deriv env dfexpr gradenv grad) xvar - with - | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1 - | None, _, None => True - | _, _, _ => False - end. - Proof. - simpl. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case. - - - Case "Number"%string. - intros _ _ grad gradenv xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Constant"%string. - intros _ _ grad gradenv xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Var"%string. - intros sv _ grad gradenv xinn. - case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]. - destruct (var_dec x sv); simpl. - + subst. - rewrite H; simpl. - rewrite lookup_update. - destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence]. - unfold addvar; simpl. - rewrite H. - lra. - + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ]. - case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros. - * rewrite lookup_update_neq by congruence. - rewrite H. - lra. - * rewrite H. - lra. - - Case "Plus"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr grad ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' grad) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Minus"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Times"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (get_annotation r * grad)%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (get_annotation l * grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (get_annotation l * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl (get_annotation r * grad)%R gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Divide"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (grad / get_annotation r)%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl (grad / get_annotation r)%R gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r )%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Square"%string. - intros _ e IHe grad gradenv xinn. - - specialize (IHe (2 * (get_annotation e) * grad)%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (2 * (get_annotation e) * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Exp"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad * exp (get_annotation e))%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * exp (get_annotation e))%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - - Case "Log"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad / get_annotation e)%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad / get_annotation e)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - - Case "Abs"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad * (sign (get_annotation e)))%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * (sign (get_annotation e)))%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Sign"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe 0%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - replace (de * 0)%R with (0)%R in IHe by lra. - replace (0 * grad)%R with (0)%R by lra. - apply IHe. - - Case "PSign"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe 0%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - replace (de * 0)%R with (0)%R in IHe by lra. - replace (0 * grad)%R with (0)%R by lra. - apply IHe. - - Case "Max"%string. - intros _ l r IHl IHr grad gradenv xinn. - specialize (IHl grad gradenv xinn). - specialize (IHr grad gradenv xinn). - destruct (Rle_dec (get_annotation l) (get_annotation r)); simpl. - destruct (df_eval_tree_deriv env r (x, DTfloat)); simpl; trivial. - destruct (df_eval_tree_deriv env l (x, DTfloat)); simpl; trivial. - Qed. -*) - -End real_pfs. - -(* - Section FreeVariablesExample. - (* We need to open the string scope in order to use "a" as a string. *) - Open Scope string_scope. - Theorem ex1 : (df_free_variables (Plus (Var "a") (Var "b"))) = "a"::"b"::nil. - Proof. - (* Reflexivity doesn't need syntactically identical things on either side of =. - * It suffices that the left-hand side beta-reduced to the right-hand side. *) - reflexivity. - Qed. - - Close Scope string_scope. - - End FreeVariablesExample. -*) diff --git a/coq/NeuralNetworks/Gen_NN.v b/coq/NeuralNetworks/Gen_NN.v deleted file mode 100644 index b8d25204..00000000 --- a/coq/NeuralNetworks/Gen_NN.v +++ /dev/null @@ -1,827 +0,0 @@ -Require Import String. -Require Import RelationClasses. -Require Import EquivDec. -Require Import Streams. -Require Import List. -Require Import ListAdd. -Require Import Rbase Rtrigo Rpower Rbasic_fun. -Require Import DefinedFunctions. -Require Import Vector. -Require Import Lra Lia. - -Require Import Floatish Utils. - -Section GenNN. - - Context {floatish_impl:floatish}. - - Local Open Scope float. - - Definition UnitDefinedFunction := @DefinedFunction floatish_impl UnitAnn. - - Record FullNN : Type := mkNN { ldims : list nat; param_var : SubVar; - f_activ : UnitDefinedFunction DTfloat; f_loss : UnitDefinedFunction DTfloat }. - - Definition mkSubVarVector (v : SubVar) (n : nat) : UnitDefinedFunction (DTVector n) := - DVector tt (fun i => Var (Sub v (proj1_sig i), DTfloat) tt). - - Definition mkVarVector (v : SubVar) (n : nat) : UnitDefinedFunction (DTVector n) := - Var (v, DTVector n) tt. - - Definition mkSubVarMatrix (v : SubVar) (n m : nat) : UnitDefinedFunction (DTMatrix n m) := - DMatrix tt (fun i j => Var (Sub (Sub v (proj1_sig i)) (proj1_sig j), DTfloat) tt). - - Definition mkVarMatrix (v : SubVar) (n m : nat) : UnitDefinedFunction (DTMatrix n m) := - Var (v, DTMatrix n m) tt. - - Definition unique_var {Ann} (df : DefinedFunction Ann DTfloat) : option var_type := - let fv := nodup vart_dec (df_free_variables df) in - match fv with - | nil => None - | v :: nil => Some v - | _ => None - end. - - Definition activation (df : UnitDefinedFunction DTfloat) (vec : list (UnitDefinedFunction DTfloat)) : option (list (UnitDefinedFunction DTfloat)) := - match unique_var df with - | Some (v, DTfloat) => Some (map (fun dfj => df_subst df (v, DTfloat) dfj) vec) - | _ => None - end. - - Definition create_activation_fun (df : UnitDefinedFunction DTfloat) : option (UnitDefinedFunction DTfloat -> UnitDefinedFunction DTfloat) := - match unique_var df with - | Some (v, DTfloat) => Some (fun val => df_subst df (v, DTfloat) val) - | _ => None - end. - - Definition mkNN2 (n1 n2 n3 : nat) (ivar wvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : (UnitDefinedFunction (DTVector n3)) := - let mat1 := mkSubVarMatrix (Sub wvar 1) n2 n1 in - let mat2 := mkSubVarMatrix (Sub wvar 2) n3 n2 in - let ivec := mkVarVector ivar n1 in - let N1 := VectorApply tt f_activ_var f_activ (MatrixVectorMult tt mat1 ivec) in - VectorApply tt f_activ_var f_activ (MatrixVectorMult tt mat2 N1). - - Definition mkVarNN2 (n1 n2 n3 : nat) (ivar wvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : (UnitDefinedFunction (DTVector n3) * list var_type) := - let mat1 := (Sub wvar 1, DTMatrix n2 n1) in - let mat2 := (Sub wvar 2, DTMatrix n3 n2) in - let ivec := Var (ivar, DTVector n1) tt in - let N1 := VectorApply tt f_activ_var f_activ (MatrixVectorMult tt (Var mat1 tt) ivec) in - (VectorApply tt f_activ_var f_activ (MatrixVectorMult tt (Var mat2 tt) N1), - mat1 :: mat2 :: nil). - - Definition mkVarMatNN2 (n1 n2 n3 nsamp: nat) (ivar wvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : (UnitDefinedFunction (DTMatrix n3 nsamp) * list var_type) := - let mat1 := (Sub wvar 1, DTMatrix n2 n1) in - let mat2 := (Sub wvar 2, DTMatrix n3 n2) in - let imat := Var (ivar, DTMatrix n1 nsamp) tt in - let N1 := MatrixApply tt f_activ_var f_activ (MatrixMult tt (Var mat1 tt) imat) in - (MatrixApply tt f_activ_var f_activ (MatrixMult tt (Var mat2 tt) N1), - mat1 :: mat2 :: nil). - - Definition mkNN_bias_step (n1 n2 : nat) (ivec : UnitDefinedFunction (DTVector n1)) - (mat : UnitDefinedFunction (DTMatrix n2 n1)) - (bias : UnitDefinedFunction (DTVector n2)) - (f_activ_var : SubVar) (f_activ : UnitDefinedFunction DTfloat) - : UnitDefinedFunction (DTVector n2) := - VectorApply tt f_activ_var f_activ (VectorPlus tt (MatrixVectorMult tt mat ivec) bias). - - Definition mkNN2_bias (n1 n2 n3 : nat) (ivar wvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : UnitDefinedFunction (DTVector n3) := - let mat1 := mkSubVarMatrix (Sub wvar 1) n2 n1 in - let b1 := mkSubVarVector (Sub wvar 1) n2 in - let mat2 := mkSubVarMatrix (Sub wvar 2) n3 n2 in - let b2 := mkSubVarVector (Sub wvar 2) n3 in - let ivec := mkVarVector ivar n1 in - let N1 := mkNN_bias_step n1 n2 ivec mat1 b1 f_activ_var f_activ in - mkNN_bias_step n2 n3 N1 mat2 b2 f_activ_var f_activ. - - Definition mkNN2_Var_bias (n1 n2 n3 : nat) (ivar wvar bvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : (UnitDefinedFunction (DTVector n3) * list var_type) := - let mat1 := (Sub wvar 1, DTMatrix n2 n1) in - let b1 := (Sub bvar 1, DTVector n2) in - let mat2 := (Sub wvar 2, DTMatrix n3 n2) in - let b2 := (Sub bvar 2, DTVector n3) in - let ivec := Var (ivar, DTVector n1) tt in - let N1 := mkNN_bias_step n1 n2 ivec (Var mat1 tt) (Var b1 tt) f_activ_var f_activ in - (mkNN_bias_step n2 n3 N1 (Var mat2 tt) (Var b2 tt) f_activ_var f_activ, - mat1 :: b1 :: mat2 :: b2 :: nil). - - Definition mkNN_Mat_bias_step (n1 n2 nsamp : nat) (imat : UnitDefinedFunction (DTMatrix n1 nsamp)) - (mat : UnitDefinedFunction (DTMatrix n2 n1)) - (bias : UnitDefinedFunction (DTVector n2)) - (f_activ_var : SubVar) (f_activ : UnitDefinedFunction DTfloat) - : UnitDefinedFunction (DTMatrix n2 nsamp) := - MatrixApply tt f_activ_var f_activ (MatrixVectorAdd tt (MatrixMult tt mat imat) bias). - - Definition mkNN2_Var_Mat_bias (n1 n2 n3 nsamp: nat) (ivar wvar bvar: SubVar) (f1_activ f2_activ : UnitDefinedFunction DTfloat) (f1_activ_var f2_activ_var : SubVar) : (UnitDefinedFunction (DTMatrix n3 nsamp) * list var_type) := - let mat1 := (Sub wvar 1, DTMatrix n2 n1) in - let b1 := (Sub bvar 1, DTVector n2) in - let mat2 := (Sub wvar 2, DTMatrix n3 n2) in - let b2 := (Sub bvar 2, DTVector n3) in - let imat := Var (ivar, DTMatrix n1 nsamp) tt in - let N1 := mkNN_Mat_bias_step n1 n2 nsamp imat (Var mat1 tt) (Var b1 tt) f1_activ_var f1_activ in - (mkNN_Mat_bias_step n2 n3 nsamp N1 (Var mat2 tt) (Var b2 tt) f2_activ_var f2_activ, - mat1 :: b1 :: mat2 :: b2 :: nil). - - Lemma vector_float_map_last_rewrite {B nvlist1 n2 v n1} : - (DTVector (last ((@domain _ B) nvlist1) n2)) = - (DTVector (last (domain((n2, v) :: nvlist1)) n1)). - Proof. - rewrite domain_cons. - rewrite last_cons. - reflexivity. - Qed. - - Lemma matrix_float_map_last_rewrite {B nvlist1 n2 v n1 nsamp} : - (DTMatrix (last ((@domain _ B) nvlist1) n2) nsamp) = - (DTMatrix (last (domain((n2, v) :: nvlist1)) n1) nsamp). - Proof. - rewrite domain_cons. - rewrite last_cons. - reflexivity. - Qed. - - Fixpoint mkNN_gen_0 (n1:nat) (nvlist : list (nat * SubVar)) - (ivec : (UnitDefinedFunction (DTVector n1))) - (f_activ_var : SubVar ) (f_activ : UnitDefinedFunction DTfloat) : - UnitDefinedFunction (DTVector (last (domain nvlist) n1)) -:= - match nvlist with - | nil => ivec - | cons (n2,v) nvlist1 => - let mat := mkSubVarMatrix v n2 n1 in - let b := mkSubVarVector v n2 in - let N := mkNN_bias_step n1 n2 ivec mat b f_activ_var f_activ in - eq_rect _ UnitDefinedFunction (mkNN_gen_0 n2 nvlist1 N f_activ_var f_activ) _ vector_float_map_last_rewrite - end. - - Fixpoint mkNN_Var_gen_0 (n1:nat) (nvlist : list (nat * SubVar)) - (ivec : (UnitDefinedFunction (DTVector n1))) - (f_activ_var : SubVar ) (f_activ : UnitDefinedFunction DTfloat) : - UnitDefinedFunction (DTVector (last (domain nvlist) n1)) -:= - match nvlist with - | nil => ivec - | cons (n2,v) nvlist1 => - let mat := mkVarMatrix v n2 n1 in - let b := mkVarVector v n2 in - let N := mkNN_bias_step n1 n2 ivec mat b f_activ_var f_activ in - eq_rect _ UnitDefinedFunction (mkNN_Var_gen_0 n2 nvlist1 N f_activ_var f_activ) _ vector_float_map_last_rewrite - end. - - Fixpoint mkNN_Mat_Var_gen_0 (nsamp n1:nat) (nvlist : list (nat * SubVar)) - (imat : (UnitDefinedFunction (DTMatrix n1 nsamp))) - (f_activ_var : SubVar ) (f_activ : UnitDefinedFunction DTfloat) : - UnitDefinedFunction (DTMatrix (last (domain nvlist) n1) nsamp) -:= - match nvlist with - | nil => imat - | cons (n2,v) nvlist1 => - let mat := mkVarMatrix v n2 n1 in - let b := mkVarVector v n2 in - let N := mkNN_Mat_bias_step n1 n2 nsamp imat mat b f_activ_var f_activ in - eq_rect _ UnitDefinedFunction (mkNN_Mat_Var_gen_0 nsamp n2 nvlist1 N f_activ_var f_activ) _ matrix_float_map_last_rewrite - end. - - Program Definition mkNN_gen (n1:nat) (nlist : list nat) (ivar wvar f_activ_var : SubVar) - (f_activ : UnitDefinedFunction DTfloat) : - UnitDefinedFunction (DTVector (last nlist n1)) := - let vlist := map (fun i => Sub wvar i) (seq 1 (length nlist)) in - let ivec := mkVarVector ivar n1 in - eq_rect _ UnitDefinedFunction - (mkNN_gen_0 n1 (combine nlist vlist) ivec f_activ_var f_activ) _ _. - Next Obligation. - f_equal. - f_equal. - rewrite combine_domain_eq; trivial. - now rewrite map_length, seq_length. - Qed. - - Program Definition mkNN_Var_gen (n1:nat) (nlist : list nat) (ivar wvar f_activ_var : SubVar) - (f_activ : UnitDefinedFunction DTfloat) : - UnitDefinedFunction (DTVector (last nlist n1)) := - let vlist := map (fun i => Sub wvar i) (seq 1 (length nlist)) in - let ivec := mkVarVector ivar n1 in - eq_rect _ UnitDefinedFunction - (mkNN_Var_gen_0 n1 (combine nlist vlist) ivec f_activ_var f_activ) _ _. - Next Obligation. - f_equal. - f_equal. - rewrite combine_domain_eq; trivial. - now rewrite map_length, seq_length. - Qed. - - Program Definition mkNN_Mat_Var_gen (nsamp n1:nat) (nlist : list nat) (ivar wvar f_activ_var : SubVar) - (f_activ : UnitDefinedFunction DTfloat) : - UnitDefinedFunction (DTMatrix (last nlist n1) nsamp) := - let vlist := map (fun i => Sub wvar i) (seq 1 (length nlist)) in - let imat := mkVarMatrix ivar n1 nsamp in - eq_rect _ UnitDefinedFunction - (mkNN_Mat_Var_gen_0 nsamp n1 (combine nlist vlist) imat f_activ_var f_activ) _ _. - Next Obligation. - f_equal. - f_equal. - rewrite combine_domain_eq; trivial. - now rewrite map_length, seq_length. - Qed. - - Definition softmax {n:nat} (NN : UnitDefinedFunction (DTVector n)) : UnitDefinedFunction (DTVector n) := - let expvar := Name "expvar" in - let NNexp := VectorApply tt expvar (Exp tt (Var (expvar, DTfloat) tt)) NN in - let NNexpscale := Divide tt (Number tt 1) (VectorSum tt NNexp) in - VectorScalMult tt NNexpscale NNexp. - - Definition L2loss (nnvar ovar : SubVar) : UnitDefinedFunction DTfloat := - Square tt ( Minus tt (Var (nnvar, DTfloat) tt) (Var (ovar, DTfloat) tt) ). - - Definition L1loss (nnvar ovar : SubVar) : UnitDefinedFunction DTfloat := - Abs tt (Minus tt (Var (nnvar, DTfloat) tt) (Var (ovar, DTfloat) tt)). - - Definition Sigmoid (var : SubVar) : UnitDefinedFunction DTfloat := - Divide tt (Number tt 1) - (Plus tt (Number tt 1) (Exp tt (Minus tt (Number tt 0) (Var (var, DTfloat) tt)))). - - Definition CrossEntropy (nnvar ovar : SubVar) : UnitDefinedFunction DTfloat := - let nnvar' := Var (nnvar, DTfloat) tt in - let ovar' := Var (ovar, DTfloat) tt in - Minus tt (Times tt (Minus tt ovar' (Number tt 1)) (Log tt (Minus tt (Number tt 1) nnvar'))) - (Times tt ovar' (Log tt nnvar')). - - Record testcases : Type := mkTest {ninput: nat; noutput: nat; ntest: nat; - datavec : Vector ((Vector float ninput) * (Vector float noutput)) ntest}. - - Definition NNinstance1samp {ninput noutput : nat} (ivar : SubVar) - (f_loss : UnitDefinedFunction DTfloat) - (f_loss_NNvar f_loss_outvar : SubVar) - (NN : UnitDefinedFunction (DTVector noutput)) (σ:df_env) - (data: (Vector float ninput) * (Vector float noutput)) - : df_env * (UnitDefinedFunction DTfloat) := - let ipair := mk_env_entry (ivar, DTVector ninput) (fst data) in - (cons ipair σ, VLossfun tt f_loss_NNvar f_loss_outvar f_loss NN (snd data)). - - Definition NNinstancebatch {ninput nsamp noutput : nat} (ivar : SubVar) - (f_loss : UnitDefinedFunction DTfloat) - (f_loss_NNvar f_loss_outvar : SubVar) - (NN : UnitDefinedFunction (DTMatrix noutput nsamp)) (σ:df_env) - (data: (Matrix float ninput nsamp) * (Matrix float noutput nsamp)) - : df_env * (UnitDefinedFunction DTfloat) := - let ipair := mk_env_entry (ivar, DTMatrix ninput nsamp) (fst data) in - (cons ipair σ, MLossfun tt f_loss_NNvar f_loss_outvar f_loss NN (snd data)). - - Definition EvalNNinstance1samp {ninput noutput : nat} (ivar : SubVar) - (f_loss : UnitDefinedFunction DTfloat) - (f_loss_NNvar f_loss_outvar : SubVar) - (NN : UnitDefinedFunction (DTVector noutput)) (σ:df_env) - (data: (Vector float ninput) * (Vector float noutput)) - : option float := - let ipair := mk_env_entry (ivar, DTVector ninput) (fst data) in - df_eval (cons ipair σ) (VLossfun tt f_loss_NNvar f_loss_outvar f_loss NN (snd data)). - - Definition EvalNNinstancebatch {ninput nsamp noutput : nat} (ivar : SubVar) - (f_loss : UnitDefinedFunction DTfloat) - (f_loss_NNvar f_loss_outvar : SubVar) - (NN : UnitDefinedFunction (DTMatrix noutput nsamp)) (σ:df_env) - (data: (Matrix float ninput nsamp) * (Matrix float noutput nsamp)) - : option float := - let ipair := mk_env_entry (ivar, DTMatrix ninput nsamp) (fst data) in - df_eval (cons ipair σ) (MLossfun tt f_loss_NNvar f_loss_outvar f_loss NN (snd data)). - - (* - Lemma NNinstance_unique_var (n1 n2 n3 : nat) (ivar : SubVar) (f_loss : DefinedFunction DTfloat) - (NN2 : DefinedFunction (DTVector n3)) (inputs : (list float)) - (outputs : Vector float n3) (v:SubVar) : - unique_var f_loss = Some v -> - NNinstance n1 n2 n3 ivar f_loss NN2 inputs outputs = - Some ( - let ipairs := (list_prod (map (fun n => (Sub ivar n)) (seq 1 n1)) - (map Number inputs)) in - let losses := VectorMinus (df_subst_list NN2 ipairs) - (DVector (vmap Number outputs)) in - (VectorSum (VectorApply v f_loss losses)) - ). - Proof. - unfold NNinstance. - intros. - rewrite H. - reflexivity. - Qed. - - Lemma NNinstance_None (n1 n2 n3 : nat) (ivar : SubVar) (f_loss : DefinedFunction DTfloat) - (NN2 : DefinedFunction (DTVector n3)) (inputs : (list float)) - (outputs : Vector float n3) : - unique_var f_loss = None -> - NNinstance n1 n2 n3 ivar f_loss NN2 inputs outputs = None. - Proof. - unfold NNinstance. - intros. - now rewrite H. - Qed. - *) - -(* Local Existing Instance floatish_interval. *) - - Definition lookup_list (σ:df_env) (lvar : list SubVar) : option (list float) := - listo_to_olist (map (fun v => (vartlookup σ (v, DTfloat)):option float) lvar). - - Definition combine_with {A:Type} {B:Type} {C:Type} (f: A -> B -> C ) - (lA : list A) (lB : list B) : list C := - map (fun '(a, b) => f a b) (combine lA lB). - - Definition combine3_with {A:Type} {B:Type} {C:Type} {D:Type} (f: A -> B -> C -> D) - (lA : list A) (lB : list B) (lC : list C) : list D := - map (fun '(a, bc) => f a (fst bc) (snd bc)) (combine lA (combine lB lC)). - - Fixpoint streamtake (n : nat) {A : Type} (st : Stream A) : (list A) * (Stream A) := - match n with - | 0 => (nil, st) - | S n' => let rst := streamtake n' (Streams.tl st) in - ((Streams.hd st)::(fst rst), snd rst) - end. - - Lemma streamtake_n (n : nat) (A : Type) (st : Stream A) : - length (fst (streamtake n st)) = n. - Proof. - generalize st. - induction n. - reflexivity. - intros. - simpl. - f_equal. - specialize IHn with (st := Streams.tl st0). - apply IHn. - Qed. - - Fixpoint env_update_first (l:df_env) (an:env_entry_type) : df_env - := match l with - | nil => nil - | fv::os => if (projT1 an) == (projT1 fv) then an::os - else fv::(env_update_first os an) - end. - - Definition env_update_list (l up:df_env) : df_env - := fold_left (env_update_first) up l. - - Definition optimize_step - (step : nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lvar : list SubVar) - (noise_st : Stream float) : (option df_env)*(Stream float) := - let lvart:list var_type := (map (fun v => (v, DTfloat)) lvar) in - let ogradvec := df_eval_gradient σ df lvart in - let alpha := 1 / (FfromZ (Z.of_nat (S step))) in - let '(lnoise, nst) := streamtake (length lvar) noise_st in - let olvals := lookup_list σ lvar in - (match (ogradvec, olvals) with - | (Some gradvec, Some lvals) => - Some (env_update_list - σ - (map (fun '(v,e) => mk_env_entry (v, DTfloat) (e:float)) - (combine lvar (combine3_with - (fun val grad noise => val - alpha*(grad + noise)) - lvals gradvec lnoise)))) - | (_, _) => None - end, nst). - - Fixpoint get_noise_vector (n: nat) (noise_st: Stream float) : - (Vector float n) * (Stream float) := - match n with - | 0 => (vnil, noise_st) - | S n' => - let noise := Streams.hd noise_st in - let nst := Streams.tl noise_st in - let '(vec, nst') := get_noise_vector n' nst in - (vcons noise vec, nst') - end. - - Fixpoint get_noise_matrix (n m: nat) (noise_st: Stream float) : - (Matrix float n m) * (Stream float) := - let '(vec, nst) := get_noise_vector m noise_st in - match n with - | 0 => (fun i => vec, nst) - | S n' => - let '(mat, nst') := get_noise_matrix n' m nst in - (vcons vec mat, nst') - end. - - Definition get_noise (t:definition_function_types) (noise_st:Stream float) : - (definition_function_types_interp t) * (Stream float) := - match t with - | DTfloat => (Streams.hd noise_st, Streams.tl noise_st) - | DTVector n => get_noise_vector n noise_st - | DTMatrix m n => get_noise_matrix m n noise_st - end. - - Program Definition update_val_gradenv (grad_env:df_env) (x:var_type) (alpha:float) : - definition_function_types_interp (snd x) -> - definition_function_types_interp (snd x) -> - definition_function_types_interp (snd x) := - (match snd x as y return snd x = y -> - definition_function_types_interp y -> - definition_function_types_interp y -> - definition_function_types_interp y with - | DTfloat => fun pf val noise => - match vartlookup grad_env x with - | Some grad => val - alpha * (coerce _ grad + noise) - | _ => val - end - | DTVector n => fun pf val noise => - match vartlookup grad_env x with - | Some grad => - fun i => - (val i) - alpha * (((coerce _ grad) i) + (noise i)) - | _ => val - end - | DTMatrix m n => fun pf val noise => - match vartlookup grad_env x with - | Some grad => - fun i j => - (val i j) - alpha * (((coerce _ grad) i j) + (noise i j)) - | _ => val - end - end) (eq_refl _) - . - Next Obligation. - rewrite pf; reflexivity. - Qed. - Next Obligation. - rewrite pf; reflexivity. - Qed. - Next Obligation. - rewrite pf; reflexivity. - Qed. - - - - Definition update_entry (entry: env_entry_type) (grad_env:df_env) (alpha:float) - (noise_st : Stream float) : (env_entry_type*Stream float) := - let x := projT1 entry in - let val := projT2 entry in - let '(noise, nst) := get_noise (snd x) noise_st in - (mk_env_entry x (update_val_gradenv grad_env x alpha val noise), nst). - - Fixpoint list_arg_iter {A B} (f: A -> B -> A * B) - (l:list A) (b: B) : (list A)*B := - match l with - | nil => (l, b) - | a :: l' => - let '(na, b') := f a b in - let '(nl, nb) := list_arg_iter f l' b' in - (na::nl, nb) - end. - - Definition harmonic_lr (step : nat) : float := 1 / (FfromZ (Z.of_nat (S step))). - - Definition optimize_step_backprop - (step : nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lr : nat -> float) - (noise_st : Stream float) (dvars : list var_type) - : (option df_env)*(Stream float) := - match df_eval_backprop_deriv σ df (gradenv_init dvars) 1 with - | Some gradenv => - let alpha := lr step in - let '(env, nst) := list_arg_iter (fun a b => update_entry a gradenv alpha b) σ noise_st in - (Some env, nst) - | _ => (None, noise_st) - end. - - Definition optimize_step_tree_backprop - (step : nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lr : nat -> float) - (noise_st : Stream float) (dvars : list var_type) - : (option df_env)*(Stream float) := - match df_eval_tree σ df with - | Some df_tree => - match df_eval_tree_backprop_deriv (* σ not-needed *) nil df_tree (gradenv_init dvars) 1 with - | Some gradenv => - let alpha := lr step in - let '(env, nst) := list_arg_iter (fun a b => update_entry a gradenv alpha b) σ noise_st in - (Some env, nst) - | _ => (None, noise_st) - end - | _ => (None, noise_st) - end. - - Fixpoint optimize_steps - (start count:nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lvar : list SubVar) - (noise_st : Stream float) : (option df_env)*(Stream float) := - match count with - | 0 => (Some σ, noise_st) - | S n => - match optimize_step start df σ lvar noise_st with - | (Some σ', noise_st') => optimize_steps (S start) n df σ' lvar noise_st' - | (None, noise_st') => (None, noise_st') - end - end. - - Fixpoint optimize_steps_backprop - (start count:nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lr : nat -> float) - (noise_st : Stream float) (dvars : list var_type) - : (option df_env)*(Stream float) := - match count with - | 0 => (Some σ, noise_st) - | S n => - match optimize_step_backprop start df σ lr noise_st dvars with - | (Some σ', noise_st') => optimize_steps_backprop (S start) n df σ' lr noise_st' dvars - | (None, noise_st') => (None, noise_st') - end - end. - - Fixpoint optimize_steps_tree_backprop - (start count:nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lr : nat -> float) - (noise_st : Stream float) (dvars : list var_type) - : (option df_env)*(Stream float) := - match count with - | 0 => (Some σ, noise_st) - | S n => - match optimize_step_tree_backprop start df σ lr noise_st dvars with - | (Some σ', noise_st') => optimize_steps_tree_backprop (S start) n df σ' lr noise_st' dvars - | (None, noise_st') => (None, noise_st') - end - end. - - -Definition Fmax (a b:float) : float := - if Fgt a b then a else b. - -Definition Fmin (a b:float) : float := - if Flt a b then a else b. - -Definition vmax {n : nat} (vec : Vector float n) - := vector_fold_right1 Fmax (FfromZ 0) id vec. - -Definition vmin {n : nat} (vec : Vector float n) - := vector_fold_right1 Fmin (FfromZ 0) id vec. - -Definition nrows {A} (l : list (list A)) := List.length l. - -Definition ncols {A} (l : list (list A)) := - match l with - | nil => 0%nat - | r :: _ => List.length r - end. - -Definition normalizeIntData (l:list (list Z)) : Matrix float (nrows l) (ncols l) := - let mat : Matrix float (nrows l) (ncols l) := - fun i j => FfromZ (List.nth (proj1_sig j) (List.nth (proj1_sig i) l nil) (0)%Z) in - let tmat := transpose mat in - let maxes := vmap vmax tmat in - let mins := vmap vmin tmat in - fun i j => ((mat i j) - (mins j))/((maxes j)- (mins j)). - -Program Definition splitLastCol {nsamp ncols : nat} (data : Matrix float nsamp ncols) - (pf : (ncols > 0)%nat) - : (Matrix float (ncols-1) nsamp) * (Matrix float 1 nsamp) := - (fun i j => data j i, fun i j => data j (i + ncols-1)%nat). -Next Obligation. - lia. -Defined. -Next Obligation. - lia. -Defined. - -Definition init_env2 (dim1 dim2 dim3 : nat) (w b : string) - (ranm1 : Matrix float dim2 dim1) - (ranm2 : Matrix float dim3 dim2) : df_env := - let wvar := Name w in - let bvar := Name b in - let wvar1 := (Sub wvar 1, DTMatrix dim2 dim1) in - let wvar2 := (Sub wvar 2, DTMatrix dim3 dim2) in - let bvar1 := (Sub bvar 1, DTVector dim2) in - let bvar2 := (Sub bvar 2, DTVector dim3) in - (* (mk_env_entry wvar1 ranm1) :: (mk_env_entry wvar2 ranm2) :: *) - (mk_env_entry wvar1 (fun i j => (ranm1 i j) / (Fsqrt (FfromZ (Z.of_nat dim1))))) :: - (mk_env_entry wvar2 (fun i j => (ranm2 i j) / (Fsqrt (FfromZ (Z.of_nat dim2))))) :: - (mk_env_entry bvar1 (ConstVector dim2 0)) :: - (mk_env_entry bvar2 (ConstVector dim3 0)) :: nil. - - Definition mkNN_wisconsin (nsamp:nat) (ivar wvar bvar f1v f2v : SubVar) : (UnitDefinedFunction (DTMatrix 1 nsamp) * list var_type) := - let f1_activ := Max tt (Var (f1v,DTfloat) tt) (Number tt 0) in - let f2_activ := Sigmoid f2v in - mkNN2_Var_Mat_bias 9 15 1 nsamp ivar wvar bvar f1_activ f2_activ f1v f2v. - - Program Definition eval_wisconsin_batch (nsamp : nat) - (σ:df_env) - (normaldata: Matrix float nsamp 10) : option float := - let ivar := (Name "i") in - let flnnv := (Name "NNv") in - let outnnv := (Name "outnnv") in - let '(NN, dvars) := mkNN_wisconsin nsamp ivar (Name "w") (Name "b") (Name "f1v") (Name "f2v") in - @EvalNNinstancebatch 9 nsamp 1 ivar (CrossEntropy flnnv outnnv) flnnv outnnv - NN σ (splitLastCol normaldata _). - Next Obligation. - lia. - Qed. - -Program Definition wisconsin_instance_batch (nsamp : nat) - (σ:df_env) - (normaldata: Matrix float nsamp 10) - : (df_env * (UnitDefinedFunction DTfloat)) * (list var_type) := - let ivar := (Name "i") in - let flnnv := (Name "NNv") in - let outnnv := (Name "outnnv") in - let '(NN, dvars) := mkNN_wisconsin nsamp ivar (Name "w") (Name "b") (Name "f1v") (Name "f2v") in - (@NNinstancebatch 9 nsamp 1 ivar (CrossEntropy flnnv outnnv) flnnv outnnv - NN - σ - (splitLastCol normaldata _), dvars). -Next Obligation. - lia. -Qed. - -CoFixpoint zeronoise : Stream float := Cons 0 zeronoise. - -(* -https://towardsdatascience.com/predict-malignancy-in-breast-cancer-tumors-with-your-own-neural-network-and-the-wisconsin-dataset-76271a05e941 -*) - -Definition wisconsin_test (nsamp count : nat) - (σ:df_env) - (normaldata: Matrix float nsamp 10): list float := - let '(nninst, dvars) := wisconsin_instance_batch nsamp σ normaldata in - let lr := fun _ => 1 / (FfromZ 100) in - let onenv := fst (optimize_steps_tree_backprop 0 count (snd nninst) (fst nninst) lr - zeronoise dvars) in - match onenv with - | Some nenv => match df_eval nenv (snd nninst) with - | Some val => val :: nil - | _ => nil - end - | _ => nil - end. - -Definition wisconsin_test_env (nsamp count : nat) - (σ:df_env) - (normaldata: Matrix float nsamp 10): df_env := - let '(nninst, dvars) := wisconsin_instance_batch nsamp σ normaldata in - let lr := fun _ => 1 / (FfromZ 100) in - let onenv := fst (optimize_steps_tree_backprop 0 count (snd nninst) (fst nninst) lr - zeronoise dvars) in - match onenv with - | Some nenv => nenv - | _ => nil - end. - - -Example xvar:var_type := (Name "x", DTfloat). -Example xfun:UnitDefinedFunction DTfloat := Var xvar tt. -Example tquad:UnitDefinedFunction DTfloat := Times tt xfun xfun. -Example quad:UnitDefinedFunction DTfloat := Minus tt (Times tt xfun xfun) (Number tt 1). -Example squad:UnitDefinedFunction DTfloat := Minus tt (Square tt xfun) (Number tt 1). -Example env : df_env := (mk_env_entry xvar (FfromZ 5))::nil. - - -Example gradenv := match df_eval_backprop_deriv env quad (gradenv_init (xvar :: nil)) 1 with - | Some gradenv => gradenv - | _ => nil - end. - -Example gradenv_tree := - match df_eval_tree env quad with - | Some df_tree => - match df_eval_tree_backprop_deriv nil df_tree (gradenv_init (xvar :: nil)) 1 with - | Some gradenv => gradenv - | _ => nil - end - | _ => nil - end. - -Example wisconsin_gradenv (nsamp : nat) - (σ:df_env) - (normaldata: Matrix float nsamp 10) : df_env := - let '((env,nn),dvars) := wisconsin_instance_batch nsamp σ normaldata in - match df_eval_backprop_deriv env nn (gradenv_init dvars) 1 with - | Some gradenv => gradenv - | _ => nil - end. - -Example wisconsin_gradenv_tree (nsamp : nat) - (σ:df_env) - (normaldata: Matrix float nsamp 10) : df_env := - let '((env,nn),dvars) := wisconsin_instance_batch nsamp σ normaldata in - match df_eval_tree env nn with - | Some df_tree => - match df_eval_tree_backprop_deriv nil df_tree (gradenv_init dvars) 1 with - | Some gradenv => gradenv - | _ => nil - end - | _ => nil - end. - - Definition mkperceptron (n1 n2 : nat) (ivar wvar bvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : (UnitDefinedFunction (DTVector n2) * list var_type) := - let mat1 := (wvar, DTMatrix n2 n1) in - let b1 := (bvar, DTVector n2) in - let ivec := Var (ivar, DTVector n1) tt in - let N1 := mkNN_bias_step n1 n2 ivec (Var mat1 tt) (Var b1 tt) f_activ_var f_activ in - (N1, mat1 :: b1 :: nil). - - Definition mkNN_test1 := - let ivar := Name "i" in - let wvar := Name "w" in - let bvar := Name "b" in - let f1v := Name "f1" in - let f1_activ := Max tt (Var (f1v,DTfloat) tt) (Number tt 0) in - let '(nn, dvars) := mkperceptron 2 2 ivar wvar bvar f1_activ f1v in - let winit := mk_env_entry (wvar, DTMatrix 2 2) (ConstMatrix 2 2 1) in - let binit := mk_env_entry (bvar, DTVector 2) (ConstVector 2 1) in - let datain := ConstVector 2 1 in - let env := (mk_env_entry (ivar, DTVector 2) (ConstVector 2 1)):: winit :: binit :: nil in - (env, nn). - - Definition mkNN_test := - let ivar := Name "i" in - let wvar := Name "w" in - let bvar := Name "b" in - let f1v := Name "f1" in -(* let f1_activ := Max tt (Var (f1v,DTfloat) tt) (Number tt 0) in *) - let f1_activ := Sigmoid f1v in - let '(nn, dvars) := mkperceptron 2 2 ivar wvar bvar f1_activ f1v in - let datain := ConstVector 2 1 in - let dataout := ConstVector 2 0 in - let loss_nnvar := Name "lv" in - let loss_ovar := Name "ov" in - let floss := CrossEntropy loss_nnvar loss_ovar in - let winit := mk_env_entry (wvar, DTMatrix 2 2) (ConstMatrix 2 2 (FfromZ 2)) in - let binit := mk_env_entry (bvar, DTVector 2) (ConstVector 2 (FfromZ 3)) in - let env := winit :: binit :: nil in - (NNinstance1samp ivar floss loss_nnvar loss_ovar nn env (datain, dataout), dvars). - -Definition NN_test (count : nat) : list float := - let '(nninst, dvars) := mkNN_test in - let onenv := fst (optimize_steps_tree_backprop 0 count (snd nninst) (fst nninst) harmonic_lr - zeronoise dvars) in - match onenv with - | Some nenv => match df_eval nenv (snd nninst) with - | Some val => val :: nil - | _ => nil - end - | _ => nil - end. - -Example NN_test_gradenv : df_env := - let '((env, nn), dvars) := mkNN_test in - match df_eval_backprop_deriv env nn (gradenv_init dvars) 1 with - | Some gradenv => gradenv - | _ => nil - end. - -Example NN_test_gradenv_tree :df_env := - let '((env, nn), dvars) := mkNN_test in - match df_eval_tree env nn with - | Some df_tree => - match df_eval_tree_backprop_deriv nil df_tree (gradenv_init dvars) 1 with - | Some gradenv => gradenv - | _ => nil - end - | _ => nil - end. - - -Definition NN_test_env (count : nat) : df_env := - let '(nninst, dvars) := mkNN_test in - let onenv := fst (optimize_steps_tree_backprop 0 count (snd nninst) (fst nninst) - harmonic_lr zeronoise dvars) in - match onenv with - | Some nenv => nenv - | _ => nil - end. - -Definition NN_test_NN := - let '((env,nn), dvars) := mkNN_test in - nn. - -Definition NN_test_val : list float := - let '((env,nn), dvars) := mkNN_test in - match df_eval env nn with - | Some val => val :: nil - | _ => nil - end. - -Definition test_update_val_gradenv - := update_val_gradenv gradenv xvar 1 (FfromZ 5) 0. - -Definition test_update : df_env := - (fst (update_entry (mk_env_entry xvar (FfromZ 5)) gradenv 1 zeronoise))::nil. - -Definition test_optimize_step_backprop - (step : nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) - (noise_st : Stream float) (dvars : list var_type) : df_env := - match fst (optimize_step_backprop step df σ harmonic_lr noise_st dvars) with - | Some env => env - | _ => nil - end. - -Definition test_optimize_step_tree_backprop - (step : nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) - (noise_st : Stream float) (dvars : list var_type) : df_env := - match fst (optimize_step_tree_backprop step df σ harmonic_lr noise_st dvars) with - | Some env => env - | _ => nil - end. - -Example testopt := test_optimize_step_backprop 0 quad env zeronoise (xvar :: nil). -Example testreeopt := test_optimize_step_tree_backprop 0 quad env zeronoise (xvar :: nil). - -Example opt := fst (optimize_steps 0 2 quad env ((fst xvar) :: nil) zeronoise). -Example opt2 := fst (optimize_steps_tree_backprop 0 2 quad env harmonic_lr zeronoise (xvar :: nil)). - -Example val := 1+1. - -End GenNN. - -(* -Local Instance floatish_interval : floatish := floatish_interval_gen 53. -Compute df_eval NN_test_env NN_test_NN. -*) diff --git a/coq/NeuralNetworks/NN.v b/coq/NeuralNetworks/NN.v deleted file mode 100644 index 5ccf88c5..00000000 --- a/coq/NeuralNetworks/NN.v +++ /dev/null @@ -1,59 +0,0 @@ -Require Import Reals.Rbase. -Require Import Reals.Rfunctions. -Require Import Arith. -Require Import String. -Require Import Vector. - -Require Import AxiomaticNormedRealVectorSpace. - -Module NN. - -Section SeriesDivergence. - -Definition converges (s: nat -> R) := - exists sum:R, infinite_sum s sum. - -Definition diverges (s: nat -> R) : Prop := - ~(converges s). - -Definition diverges_right (s: nat -> R) : Prop := - forall m: R, - exists N: nat, - forall n: nat, - n >= N -> Rgt (s n) m. - -Definition diverges_left (s: nat -> R) : Prop := - forall m: R, - exists N: nat, - forall n:nat, - n >= N -> Rlt (s n) m. - -End SeriesDivergence. - -Section AssumptionC. - -Local Open Scope R_scope. -Definition Assumption_C_1 (ak : nat -> R) : Prop := - let - ak_squared (n : nat) := (ak n)^2 - in - (forall n, ak n >= 0) /\ - diverges_right ak /\ - converges ak_squared. - -Definition Assumption_C_2 (s: Set) (x: nat -> rvector s) : Prop := - exists M : R, - forall k: nat, norm s (x k) < M. - -(* TODO *) -Definition Assumption_C_3 (zeta: nat -> R) : Prop := - ZeroMeanBoundedVariance zeta. - -Definition Assumption_C (s: Set) (zeta: nat -> R) (alpha: nat -> R) (x : nat -> rvector s) : Prop := - Assumption_C_1 alpha /\ Assumption_C_2 s x /\ Assumption_C_3 zeta. - -Local Close Scope R_scope. - -End AssumptionC. - -End NN. diff --git a/coq/NeuralNetworks/nDefinedFunctions.v b/coq/NeuralNetworks/nDefinedFunctions.v deleted file mode 100644 index d5e3f8b6..00000000 --- a/coq/NeuralNetworks/nDefinedFunctions.v +++ /dev/null @@ -1,15740 +0,0 @@ -Require Import String. -Require Import EquivDec. -Require Import RelationClasses. -Require Import List. -Require Import Permutation. -Require Import NPeano. -Require Import Lra Lia. -Require Reals. -Require Import Eqdep_dec. - -Require Import Floatish. -Require Import Utils. -Require Import derivlemmas. -Require VectorDef. -Require Import nvector. - -Set Bullet Behavior "Strict Subproofs". - -Section DefinedFunctions. - - Context {floatish_impl:floatish}. - Local Open Scope float. - -(* Declare Scope df_scope. *) - -(* in pytorch relu(f)' if f <=0 then 0 else f' *) -(* in pytorch abs(f)' = f'*sign(f) *) -(* max(a,b)' = if a<=b then b' else a' *) -(* min(a,b)' = if a>=b then b' else a' *) -(* x = Variable(torch.tensor(0.0), requires_grad=True) *) -(* z = torch.min(x*x, x); z.backward(); print(x.grad) = 1 *) -(* x.grad.data.zero_() between tests *) -(* relu behaves like max(x, 0), not max(0,x), i.e. relu(x)' at 0 = 0 *) - - - Section Definitions. - - Definition var := string. - - Inductive SubVar : Set := - | Name (s : string) - | Sub (v : SubVar) (i : nat). - - - Definition var_dec : forall v1 v2 : SubVar, {v1 = v2} + {v1 <> v2}. - Proof. - decide equality. - apply string_dec. - apply Nat.eq_dec. - Defined. - - Global Instance var_eqdec : EqDec SubVar eq. - Proof. - intros x y. - apply var_dec. - Defined. - - (* A subset of defined functions *) - - - Inductive definition_function_types - := DTfloat - | DTVector (n:nat) - | DTMatrix (m n:nat). - - Definition definition_function_types_interp (dft:definition_function_types) : Type - := match dft with - | DTfloat => float - | DTVector n => vector float n - | DTMatrix m n => matrix float m n - end. - - Inductive data_type : definition_function_types -> Type - := DataFloat : data_type DTfloat - | DataVector n (v:vector float n) : data_type (DTVector n) - | DataMatrix m n (mat:matrix float m n) : data_type (DTMatrix m n). - - Definition var_type := (SubVar * definition_function_types)%type. - - Definition definition_function_types_dec : forall v1 v2 : definition_function_types, {v1 = v2} + {v1 <> v2}. - Proof. - decide equality; apply Nat.eq_dec. - Defined. - - Definition vart_dec : forall v1 v2 : var_type, {v1 = v2} + {v1 <> v2}. - Proof. - decide equality. - - apply definition_function_types_dec. - - apply var_dec. - Defined. - - Global Instance vart_eqdec : EqDec var_type eq. - Proof. - intros ??. - apply vart_dec. - Defined. - - Lemma var_type_UIP_refl {x:var_type} (e:x=x) : e = eq_refl x. - Proof. - apply (UIP_dec vart_dec). - Qed. - - Lemma definition_function_types_UIP_refl {x:definition_function_types} (e:x=x) : e = eq_refl x. - Proof. - apply (UIP_dec definition_function_types_dec). - Qed. - - Definition env_entry_type := {v:var_type & definition_function_types_interp (snd v)}. - Definition df_env := list env_entry_type. - - Definition mk_env_entry v e : env_entry_type - := let P := fun xv => definition_function_types_interp (snd xv) in - existT P v e. - - Definition UnitAnn: definition_function_types->Type := fun _ => unit. - Definition EvalAnn: definition_function_types->Type := definition_function_types_interp. - - Inductive DefinedFunction {Ann:definition_function_types->Type} : definition_function_types -> Type := - | Number (ann:Ann DTfloat) (x : float) : DefinedFunction DTfloat - | Constant {t:definition_function_types} (ann:Ann t) (x : definition_function_types_interp t) : DefinedFunction t - - | DVector {n} (ann:Ann (DTVector n)) (x : vector (DefinedFunction DTfloat) n) : DefinedFunction (DTVector n) - - | DMatrix {n m} (ann:Ann (DTMatrix n m)) (x : matrix (DefinedFunction DTfloat) n m) : DefinedFunction (DTMatrix n m) - - - | Var (v : var_type) (ann: Ann (snd v)) : DefinedFunction (snd v) - | Plus (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Minus (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Times (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Divide (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Square (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Exp (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Log (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Abs (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Sign (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | PSign (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat - | Max (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat - | VectorDot {n} (ann:Ann DTfloat) (l r: DefinedFunction (DTVector n)) : DefinedFunction DTfloat - | VectorSum {n} (ann:Ann DTfloat) (v: DefinedFunction (DTVector n)) : DefinedFunction DTfloat - | MatrixSum {m n} (ann:Ann DTfloat) (v: DefinedFunction (DTMatrix m n)) : DefinedFunction DTfloat - | VectorElem {n} (ann:Ann DTfloat) (l:DefinedFunction (DTVector n)) (i:{x:nat|x Prop) - (f : forall (ann : UnitAnn DTfloat) (x : float), - P DTfloat (Number ann x)) - (f0 : forall (t : definition_function_types) - (ann : UnitAnn t) (x : definition_function_types_interp t), P t (Constant ann x)) - (f1 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (x : vector (DefinedFunction UnitAnn DTfloat) n) - (f: vforall (P DTfloat) x), - P (DTVector n) (DVector ann x)) - (f2 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (x : matrix (DefinedFunction UnitAnn DTfloat) n m) - (f: mforall (P DTfloat) x), - P (DTMatrix n m) (DMatrix ann x)) - (f3 : forall (v : var_type) (ann : UnitAnn (snd v)), - P (snd v) (Var v ann)) - (f4 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Plus ann l r)) - (f5 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Minus ann l r)) - (f6 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Times ann l r)) - (f7 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Divide ann l r)) - (f8 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Square ann e)) - (f9 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Exp ann e)) - (f10 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Log ann e)) - (f11 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Abs ann e)) - (f12 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Sign ann e)) - (f13 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (PSign ann e)) - (f14 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Max ann l r)) - (f15 : forall (n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P DTfloat (VectorDot ann l r)) - (f16 : forall (n : nat) (ann : UnitAnn DTfloat) - (v : DefinedFunction UnitAnn (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v)) - (f17 : forall (m n : nat) (ann : UnitAnn DTfloat) - (v : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v)) - (f18 : forall (n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i)) - (f19 : forall (m n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}), - P DTfloat (MatrixElem ann l i j)) - (f20 : forall (m n : nat) (ann : UnitAnn (DTVector m)) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r)) - (f21 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n)) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction UnitAnn (DTVector m), - P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r)) - (f22 : forall (m p n : nat) (ann : UnitAnn (DTMatrix m n)) - (l : DefinedFunction UnitAnn (DTMatrix m p)), - P (DTMatrix m p) l -> - forall r : DefinedFunction UnitAnn (DTMatrix p n), - P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r)) - (f23 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r)) - (f24 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r)) - (f25 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (l : DefinedFunction UnitAnn (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r)) - (f26 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (l : DefinedFunction UnitAnn (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r)) - (f27 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (x : DefinedFunction UnitAnn DTfloat), - P DTfloat x -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l)) - (f28 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (x : DefinedFunction UnitAnn DTfloat), - P DTfloat x -> - forall l : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l)) - (f29 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l)) - (f30 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTMatrix m n), - P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l)) - (f31 : forall (n : nat) (ann : UnitAnn DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> forall r : vector float n, P DTfloat (VLossfun ann v1 v2 s l r)) - (f32 : forall (m n : nat) (ann : UnitAnn DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTMatrix m n), - P (DTMatrix m n) l -> - forall r : matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r)) - := -fix -F (d : definition_function_types) - (d0 : DefinedFunction UnitAnn d) {struct d0} : P d d0 := - match d0 as d2 in (DefinedFunction _ d1) return (P d1 d2) with - | Number ann x => f ann x - | @Constant _ t ann x => f0 t ann x - | @DVector _ n ann x => - f1 n ann x - ((fix F1 n (x:vector (DefinedFunction UnitAnn DTfloat) n) : vforall (P DTfloat) x := - match x with - | Vector.nil => Vector.Forall_nil (P DTfloat) - | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h) (F1 _ tl) - end) n x) - | @DMatrix _ n m ann x => - f2 n m ann x - ((fix F2 n m (x:matrix (DefinedFunction UnitAnn DTfloat) n m) : mforall (P DTfloat) x := - match x with - | Vector.nil => Vector.Forall_nil (vforall (P DTfloat)) - - | Vector.cons h _ tl => - Vector.Forall_cons _ _ _ - ((fix F1 m (x:vector (DefinedFunction UnitAnn DTfloat) m) : vforall (P DTfloat) x := - match x with - | Vector.nil => Vector.Forall_nil (P DTfloat) - | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h) (F1 _ tl) - end) m h) - (F2 _ _ tl) - end) n m x) - | Var v ann => f3 v ann - | Plus ann l r => f4 ann l (F DTfloat l) r (F DTfloat r) - | Minus ann l r => f5 ann l (F DTfloat l) r (F DTfloat r) - | Times ann l r => f6 ann l (F DTfloat l) r (F DTfloat r) - | Divide ann l r => f7 ann l (F DTfloat l) r (F DTfloat r) - | Square ann e => f8 ann e (F DTfloat e) - | Exp ann e => f9 ann e (F DTfloat e) - | Log ann e => f10 ann e (F DTfloat e) - | Abs ann e => f11 ann e (F DTfloat e) - | Sign ann e => f12 ann e (F DTfloat e) - | PSign ann e => f13 ann e (F DTfloat e) - | Max ann l r => f14 ann l (F DTfloat l) r (F DTfloat r) - | @VectorDot _ n ann l r => f15 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @VectorSum _ n ann v => f16 n ann v (F (DTVector n) v) - | @MatrixSum _ m n ann v => f17 m n ann v (F (DTMatrix m n) v) - | @VectorElem _ n ann l i => f18 n ann l (F (DTVector n) l) i - | @MatrixElem _ m n ann l i j => f19 m n ann l (F (DTMatrix m n) l) i j - | @MatrixVectorMult _ m n ann l r => - f20 m n ann l (F (DTMatrix m n) l) r (F (DTVector n) r) - | @MatrixVectorAdd _ m n ann l r => - f21 m n ann l (F (DTMatrix m n) l) r (F (DTVector m) r) - | @MatrixMult _ m p n ann l r => - f22 m p n ann l (F (DTMatrix m p) l) r (F (DTMatrix p n) r) - | @VectorPlus _ n ann l r => f23 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @VectorMinus _ n ann l r => f24 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @MatrixPlus _ n m ann l r => f25 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r) - | @MatrixMinus _ n m ann l r => - f26 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r) - | @VectorScalMult _ n ann x l => f27 n ann x (F DTfloat x) l (F (DTVector n) l) - | @MatrixScalMult _ n m ann x l => f28 n m ann x (F DTfloat x) l (F (DTMatrix n m) l) - | @VectorApply _ n ann v s l => f29 n ann v s (F DTfloat s) l (F (DTVector n) l) - | @MatrixApply _ m n ann v s l => - f30 m n ann v s (F DTfloat s) l (F (DTMatrix m n) l) - | @VLossfun _ n ann v1 v2 s l r => - f31 n ann v1 v2 s (F DTfloat s) l (F (DTVector n) l) r - | @MLossfun _ m n ann v1 v2 s l r => - f32 m n ann v1 v2 s (F DTfloat s) l (F (DTMatrix m n) l) r - end. - -Definition DefinedFunction_ind_simpl {Ann} - (P : forall (d : definition_function_types), DefinedFunction Ann d -> Prop) - (f : forall (ann : Ann DTfloat) (x : float), - P DTfloat (Number ann x)) - (f0 : forall (t : definition_function_types) - (ann : Ann t) (x : definition_function_types_interp t), P t (Constant ann x)) - (f1 : forall (n : nat) (ann : Ann (DTVector n)) - (x : vector (DefinedFunction Ann DTfloat) n) - (f: vforall (P DTfloat) x), - P (DTVector n) (DVector ann x)) - (f2 : forall (n m : nat) (ann : Ann (DTMatrix n m)) - (x : matrix (DefinedFunction Ann DTfloat) n m) - (f: mforall (P DTfloat) x), - P (DTMatrix n m) (DMatrix ann x)) - (f3 : forall (v : var_type) (ann : Ann (snd v)), - P (snd v) (Var v ann)) - (f4 : forall (ann : Ann DTfloat) - (l : DefinedFunction Ann DTfloat), - P DTfloat l -> - forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Plus ann l r)) - (f5 : forall (ann : Ann DTfloat) - (l : DefinedFunction Ann DTfloat), - P DTfloat l -> - forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Minus ann l r)) - (f6 : forall (ann : Ann DTfloat) - (l : DefinedFunction Ann DTfloat), - P DTfloat l -> - forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Times ann l r)) - (f7 : forall (ann : Ann DTfloat) - (l : DefinedFunction Ann DTfloat), - P DTfloat l -> - forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Divide ann l r)) - (f8 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Square ann e)) - (f9 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Exp ann e)) - (f10 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Log ann e)) - (f11 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Abs ann e)) - (f12 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Sign ann e)) - (f13 : forall (ann : Ann DTfloat) - (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (PSign ann e)) - (f14 : forall (ann : Ann DTfloat) - (l : DefinedFunction Ann DTfloat), - P DTfloat l -> - forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Max ann l r)) - (f15 : forall (n : nat) (ann : Ann DTfloat) - (l : DefinedFunction Ann (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction Ann (DTVector n), - P (DTVector n) r -> P DTfloat (VectorDot ann l r)) - (f16 : forall (n : nat) (ann : Ann DTfloat) - (v : DefinedFunction Ann (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v)) - (f17 : forall (m n : nat) (ann : Ann DTfloat) - (v : DefinedFunction Ann (DTMatrix m n)), - P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v)) - (f18 : forall (n : nat) (ann : Ann DTfloat) - (l : DefinedFunction Ann (DTVector n)), - P (DTVector n) l -> - forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i)) - (f19 : forall (m n : nat) (ann : Ann DTfloat) - (l : DefinedFunction Ann (DTMatrix m n)), - P (DTMatrix m n) l -> - forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}), - P DTfloat (MatrixElem ann l i j)) - (f20 : forall (m n : nat) (ann : Ann (DTVector m)) - (l : DefinedFunction Ann (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction Ann (DTVector n), - P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r)) - (f21 : forall (m n : nat) (ann : Ann (DTMatrix m n)) - (l : DefinedFunction Ann (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction Ann (DTVector m), - P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r)) - (f22 : forall (m p n : nat) (ann : Ann (DTMatrix m n)) - (l : DefinedFunction Ann (DTMatrix m p)), - P (DTMatrix m p) l -> - forall r : DefinedFunction Ann (DTMatrix p n), - P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r)) - (f23 : forall (n : nat) (ann : Ann (DTVector n)) - (l : DefinedFunction Ann (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction Ann (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r)) - (f24 : forall (n : nat) (ann : Ann (DTVector n)) - (l : DefinedFunction Ann (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction Ann (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r)) - (f25 : forall (n m : nat) (ann : Ann (DTMatrix n m)) - (l : DefinedFunction Ann (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction Ann (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r)) - (f26 : forall (n m : nat) (ann : Ann (DTMatrix n m)) - (l : DefinedFunction Ann (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction Ann (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r)) - (f27 : forall (n : nat) (ann : Ann (DTVector n)) - (x : DefinedFunction Ann DTfloat), - P DTfloat x -> - forall l : DefinedFunction Ann (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l)) - (f28 : forall (n m : nat) (ann : Ann (DTMatrix n m)) - (x : DefinedFunction Ann DTfloat), - P DTfloat x -> - forall l : DefinedFunction Ann (DTMatrix n m), - P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l)) - (f29 : forall (n : nat) (ann : Ann (DTVector n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - forall l : DefinedFunction Ann (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l)) - (f30 : forall (m n : nat) (ann : Ann (DTMatrix m n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - forall l : DefinedFunction Ann (DTMatrix m n), - P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l)) - (f31 : forall (n : nat) (ann : Ann DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - forall l : DefinedFunction Ann (DTVector n), - P (DTVector n) l -> forall r : vector float n, P DTfloat (VLossfun ann v1 v2 s l r)) - (f32 : forall (m n : nat) (ann : Ann DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - forall l : DefinedFunction Ann (DTMatrix m n), - P (DTMatrix m n) l -> - forall r : matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r)) - := -fix -F (d : definition_function_types) - (d0 : DefinedFunction Ann d) {struct d0} : P d d0 := - match d0 as d2 in (DefinedFunction _ d1) return (P d1 d2) with - | Number ann x => f ann x - | @Constant _ t ann x => f0 t ann x - | @DVector _ n ann x => - f1 n ann x - ((fix F1 n (x:vector (DefinedFunction Ann DTfloat) n) : vforall (P DTfloat) x := - match x with - | Vector.nil => Vector.Forall_nil (P DTfloat) - | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h) (F1 _ tl) - end) n x) - | @DMatrix _ n m ann x => - f2 n m ann x - ((fix F2 n m (x:matrix (DefinedFunction Ann DTfloat) n m) : mforall (P DTfloat) x := - match x with - | Vector.nil => Vector.Forall_nil (vforall (P DTfloat)) - - | Vector.cons h _ tl => - Vector.Forall_cons _ _ _ - ((fix F1 m (x:vector (DefinedFunction Ann DTfloat) m) : vforall (P DTfloat) x := - match x with - | Vector.nil => Vector.Forall_nil (P DTfloat) - | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h) (F1 _ tl) - end) m h) - (F2 _ _ tl) - end) n m x) - | Var v ann => f3 v ann - | Plus ann l r => f4 ann l (F DTfloat l) r (F DTfloat r) - | Minus ann l r => f5 ann l (F DTfloat l) r (F DTfloat r) - | Times ann l r => f6 ann l (F DTfloat l) r (F DTfloat r) - | Divide ann l r => f7 ann l (F DTfloat l) r (F DTfloat r) - | Square ann e => f8 ann e (F DTfloat e) - | Exp ann e => f9 ann e (F DTfloat e) - | Log ann e => f10 ann e (F DTfloat e) - | Abs ann e => f11 ann e (F DTfloat e) - | Sign ann e => f12 ann e (F DTfloat e) - | PSign ann e => f13 ann e (F DTfloat e) - | Max ann l r => f14 ann l (F DTfloat l) r (F DTfloat r) - | @VectorDot _ n ann l r => f15 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @VectorSum _ n ann v => f16 n ann v (F (DTVector n) v) - | @MatrixSum _ m n ann v => f17 m n ann v (F (DTMatrix m n) v) - | @VectorElem _ n ann l i => f18 n ann l (F (DTVector n) l) i - | @MatrixElem _ m n ann l i j => f19 m n ann l (F (DTMatrix m n) l) i j - | @MatrixVectorMult _ m n ann l r => - f20 m n ann l (F (DTMatrix m n) l) r (F (DTVector n) r) - | @MatrixVectorAdd _ m n ann l r => - f21 m n ann l (F (DTMatrix m n) l) r (F (DTVector m) r) - | @MatrixMult _ m p n ann l r => - f22 m p n ann l (F (DTMatrix m p) l) r (F (DTMatrix p n) r) - | @VectorPlus _ n ann l r => f23 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @VectorMinus _ n ann l r => f24 n ann l (F (DTVector n) l) r (F (DTVector n) r) - | @MatrixPlus _ n m ann l r => f25 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r) - | @MatrixMinus _ n m ann l r => - f26 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r) - | @VectorScalMult _ n ann x l => f27 n ann x (F DTfloat x) l (F (DTVector n) l) - | @MatrixScalMult _ n m ann x l => f28 n m ann x (F DTfloat x) l (F (DTMatrix n m) l) - | @VectorApply _ n ann v s l => f29 n ann v s l (F (DTVector n) l) - | @MatrixApply _ m n ann v s l => - f30 m n ann v s l (F (DTMatrix m n) l) - | @VLossfun _ n ann v1 v2 s l r => - f31 n ann v1 v2 s l (F (DTVector n) l) r - | @MLossfun _ m n ann v1 v2 s l r => - f32 m n ann v1 v2 s l (F (DTMatrix m n) l) r - end. - - Definition get_annotation {Ann T} (df:DefinedFunction Ann T) : Ann T - := match df with - | Number ann _ => ann - | Constant _ ann _ => ann - | DVector _ ann _ => ann - | DMatrix _ _ ann _ => ann - | Var _ ann => ann - | Plus ann _ _ => ann - | Minus ann _ _ => ann - | Times ann _ _ => ann - | Divide ann _ _ => ann - | Square ann _ => ann - | Exp ann _ => ann - | Log ann _ => ann - | Abs ann _ => ann - | Sign ann _ => ann - | PSign ann _ => ann - | Max ann _ _ => ann - | VectorDot _ ann _ _ => ann - | VectorSum _ ann _ => ann - | MatrixSum _ _ ann _ => ann - | VectorElem _ ann _ _ => ann - | MatrixElem _ _ ann _ _ _ => ann - | MatrixVectorMult _ _ ann _ _ => ann - | MatrixVectorAdd _ _ ann _ _ => ann - | MatrixMult _ _ _ ann _ _ => ann - | VectorPlus _ ann _ _ => ann - | VectorMinus _ ann _ _ => ann - | MatrixPlus _ _ ann _ _ => ann - | MatrixMinus _ _ ann _ _ => ann - | VectorScalMult _ ann _ _ => ann - | MatrixScalMult _ _ ann _ _ => ann - | VectorApply _ ann _ _ _ => ann - | MatrixApply _ _ ann _ _ _ => ann - | VLossfun _ ann _ _ _ _ _ => ann - | MLossfun _ _ ann _ _ _ _ _ => ann - end. - - Definition dft_eq_dec : - forall (t1 t2 : definition_function_types), {t1 = t2} + {t1 <> t2}. - Proof. - decide equality. - decide equality. - apply Nat.eq_dec. - apply Nat.eq_dec. - Defined. - - Global Instance dft_eqdec : EqDec definition_function_types eq. - Proof. - intros ??. - apply dft_eq_dec. - Defined. - - End Definitions. - - Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "DVector"%string - | Case_aux c "DMatrix"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string - | Case_aux c "VectorDot"%string - | Case_aux c "VectorSum"%string - | Case_aux c "MatrixSum"%string - | Case_aux c "VectorElem"%string - | Case_aux c "MatrixElem"%string - | Case_aux c "MatrixVectorMult"%string - | Case_aux c "MatrixVectorAdd"%string - | Case_aux c "MatrixMult"%string - | Case_aux c "VectorPlus"%string - | Case_aux c "VectorMinus"%string - | Case_aux c "MatrixPlus"%string - | Case_aux c "MatrixMinus"%string - | Case_aux c "VectorScalMult"%string - | Case_aux c "MatrixScalMult"%string - | Case_aux c "VectorApply"%string - | Case_aux c "MatrixApply"%string - | Case_aux c "VLossfun"%string - | Case_aux c "MLossfun"%string]. - - - Ltac refl_simpler := - repeat - match goal with - | [H: @eq var_type _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H) - | [H: @equiv var_type _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H) - | [H: @eq definition_function_types _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H) - | [H: @equiv definition_function_types _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H) - end. - - - Definition df_plus (df1 df2 : DefinedFunction UnitAnn DTfloat) : DefinedFunction UnitAnn DTfloat := - Plus tt df1 df2. - - Definition df_times (df1 df2 : DefinedFunction UnitAnn DTfloat) : DefinedFunction UnitAnn DTfloat := - Times tt df1 df2. - - Definition defined_sum {m} (v:vector (DefinedFunction UnitAnn DTfloat) m) : DefinedFunction UnitAnn DTfloat - := Vector.fold_right (fun a b => Plus tt a b) v (Number tt 0). - - Definition vsum {m:nat} (v:vector float m) : float - := Vector.fold_right Fplus v 0. - - Definition msum {m n:nat} (v:matrix float m n) : float := - vsum (vmap vsum v). - - Definition transpose {A} {n m:nat} (mat:matrix A n m) := - build_matrix (fun i j => mnth mat j i). - -(* defined in nvector - Definition matrix_vector_mult {n m} (l : Matrix float n m)(r : vector float m) : vector float n := - fun i => vsum (fun j => (l i j) * (r j)). - - Definition matrix_vector_add {n m} (l : Matrix float n m) (r : vector float n) : Matrix float n m := fun i j => (l i j) + (r i). - - Definition matrix_mult {n m p} (l : Matrix float n m)(r : Matrix float m p) : Matrix float n p := - fun i k => vsum (fun j => (l i j) * (r j k)). -*) - - Section deriv. - - - Section subst. - - Definition substvar {Ann} (v vv:var_type) (e':DefinedFunction Ann (snd v)) (e:DefinedFunction Ann (snd vv)) : (DefinedFunction Ann (snd vv)) := - match snd v == snd vv with - | left pf => eq_rect _ (fun t => DefinedFunction Ann t) e' _ pf - | right pf => e - end. - - Fixpoint df_subst {T Ann} (df: DefinedFunction Ann T) (v:var_type) (e':DefinedFunction UnitAnn (snd v)) {struct df} := - match df with - | Number _ x => Number tt x - | Constant t _ x => Constant tt x - | DVector n _ df => DVector tt (vmap (fun x => df_subst x v e') df) - | DMatrix n m _ df => DMatrix tt (mmap (fun x => df_subst x v e') df) - | Var vvar _ => substvar v vvar e' (Var vvar tt) - | Plus _ l r => Plus tt (df_subst l v e') (df_subst r v e') - | Times _ l r => Times tt (df_subst l v e') (df_subst r v e') - | Minus _ l r => Minus tt (df_subst l v e') (df_subst r v e') - | Divide _ l r => Divide tt (df_subst l v e') (df_subst r v e') - | Square _ e => Square tt (df_subst e v e') - | Exp _ e => Exp tt (df_subst e v e') - | Log _ e => Log tt (df_subst e v e') - | Abs _ e => Abs tt (df_subst e v e') - | Sign _ e => Sign tt (df_subst e v e') - | PSign _ e => PSign tt (df_subst e v e') - | Max _ l r => Max tt (df_subst l v e') (df_subst r v e') - | VectorElem n _ l i => VectorElem tt (df_subst l v e') i - | MatrixElem m n _ l i j => MatrixElem tt (df_subst l v e') i j - | VectorDot n _ l r => - VectorDot tt (df_subst l v e') (df_subst r v e') - | VectorSum n _ e => - VectorSum tt (df_subst e v e') - | MatrixSum n m _ e => - MatrixSum tt (df_subst e v e') - | VectorScalMult n _ x r => - VectorScalMult tt (df_subst x v e') (df_subst r v e') - | MatrixScalMult n m _ x r => - MatrixScalMult tt (df_subst x v e') (df_subst r v e') - | MatrixVectorMult n m _ l r => - MatrixVectorMult tt (df_subst l v e') (df_subst r v e') - | MatrixVectorAdd n m _ l r => - MatrixVectorAdd tt (df_subst l v e') (df_subst r v e') - | MatrixMult n m p _ l r => - MatrixMult tt (df_subst l v e') (df_subst r v e') - | VectorPlus n _ l r => - VectorPlus tt (df_subst l v e') (df_subst r v e') - | VectorMinus n _ l r => - VectorMinus tt (df_subst l v e') (df_subst r v e') - | MatrixPlus n m _ l r => - MatrixPlus tt (df_subst l v e') (df_subst r v e') - | MatrixMinus n m _ l r => - MatrixMinus tt (df_subst l v e') (df_subst r v e') - | VectorApply n _ x s l => - VectorApply tt x s (df_subst l v e') - | MatrixApply n m _ x s l => - MatrixApply tt x s (df_subst l v e') - | VLossfun n _ v1 v2 s l r => - VLossfun tt v1 v2 s (df_subst l v e') r - | MLossfun n m _ v1 v2 s l r => - MLossfun tt v1 v2 s (df_subst l v e') r - end. - - Definition df_substp {T Ann} := - fun e (ve':{v:var_type & DefinedFunction UnitAnn (snd v)}) => - @df_subst T Ann e (projT1 ve') (projT2 ve'). - - Definition df_subst_list {T} (e:DefinedFunction UnitAnn T) - (l:list {v:var_type & DefinedFunction UnitAnn (snd v)}) : DefinedFunction UnitAnn T - := fold_left (@df_substp T UnitAnn) l e. - - End subst. - - -(* restrict to scalar v? *) - Fixpoint df_deriv {T} (df:DefinedFunction UnitAnn T) (v:var_type) {struct df} : DefinedFunction UnitAnn T - := (match df with - | Number _ _ => Number tt 0 - | Constant t _ x => Constant tt - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ df => DVector tt (vmap (fun x => df_deriv x v) df) - | DMatrix n m _ df => DMatrix tt (mmap (fun x => df_deriv x v) df) - | Var x _ => Constant tt - match snd x as y return definition_function_types_interp y with - | DTfloat => if x == v then 1 else 0 - | DTVector n => ConstVector n (if x == v then 1 else 0) - | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0) - end - | Plus _ l r => Plus tt (df_deriv l v) (df_deriv r v) - | Minus _ l r => Minus tt (df_deriv l v) (df_deriv r v) - | Times _ l r => Plus tt (Times tt l (df_deriv r v)) - (Times tt (df_deriv l v) r) - | Divide _ l r => Minus tt - (Divide tt (df_deriv l v) r) - (Divide tt (Times tt l (df_deriv r v)) - (Times tt r r)) - | Square _ e => Times tt - (Times tt (Number tt 2) e) (df_deriv e v) - | Exp _ e => Times tt (df_deriv e v) (Exp tt e) - | Log _ e => Divide tt (df_deriv e v) e - | Abs _ e => Times tt (df_deriv e v) (Sign tt e) - | Sign _ e => Number tt 0 - | PSign _ e => Number tt 0 - | Max _ l r => Divide tt - (Plus tt - (Times tt (Minus tt - (df_deriv r v) - (df_deriv l v)) - (PSign tt (Minus tt r l))) - (Plus tt (df_deriv r v) (df_deriv l v))) - (Number tt 2) - | VectorElem n _ l i => VectorElem tt (df_deriv l v) i - | MatrixElem m n _ l i j => MatrixElem tt (df_deriv l v) i j - | VectorDot n _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - Plus tt (VectorDot tt ll r) (VectorDot tt l rr) - | VectorSum n _ l => - let ll := df_deriv l v in - VectorSum tt ll - | MatrixSum m n _ l => - let ll := df_deriv l v in - MatrixSum tt ll - | VectorScalMult n _ x r => - let xx := df_deriv x v in - let rr := df_deriv r v in - VectorPlus tt - (VectorScalMult tt xx r) - (VectorScalMult tt x rr) - | MatrixScalMult n m _ x r => - let xx := df_deriv x v in - let rr := df_deriv r v in - MatrixPlus tt - (MatrixScalMult tt xx r) (MatrixScalMult tt x rr) - | MatrixVectorMult n m _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - VectorPlus tt (MatrixVectorMult tt ll r) - (MatrixVectorMult tt l rr) - | MatrixVectorAdd n m _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - MatrixVectorAdd tt ll rr - | MatrixMult n m p _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - MatrixPlus tt (MatrixMult tt ll r) (MatrixMult tt l rr) - | VectorPlus n _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - VectorPlus tt ll rr - | VectorMinus n _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - VectorMinus tt ll rr - | MatrixPlus n m _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - MatrixPlus tt ll rr - | MatrixMinus n m _ l r => - let ll := df_deriv l v in - let rr := df_deriv r v in - MatrixMinus tt ll rr - | VectorApply n _ x s r => - let rr := df_deriv r v in - let ss := df_deriv s (x, DTfloat) in - DVector tt (build_vector (fun i => Times tt (VectorElem tt rr i) (df_subst ss (x, DTfloat) (VectorElem tt r i)))) - | MatrixApply n m _ x s r => - let rr := df_deriv r v in - let ss := df_deriv s (x, DTfloat) in - DMatrix tt (build_matrix (fun i j => Times tt (MatrixElem tt rr i j) (df_subst ss (x, DTfloat) (MatrixElem tt r i j)))) - | VLossfun n _ v1 v2 s l r => - let ll := df_deriv l v in - let ss := df_deriv s (v1, DTfloat) in - VectorDot tt ll - (DVector tt (build_vector (fun i => - df_subst (df_subst ss (v1, DTfloat) (VectorElem tt l i)) - (v2, DTfloat) (Number tt (vnth r i))))) - | MLossfun n m _ v1 v2 s l r => - let ll := df_deriv l v in - let ss := df_deriv s (v1, DTfloat) in - MatrixSum tt - (DMatrix tt - (build_matrix (fun i j => - (Divide tt - (Times tt (MatrixElem tt ll i j) - (df_subst (df_subst ss (v1, DTfloat) (MatrixElem tt l i j)) - (v2, DTfloat) (Number tt (mnth r i j)))) - (Number tt (FfromZ (Z.of_nat m))))))) - end). - - Definition df_gradient {T} (df:DefinedFunction UnitAnn T) (lv:list var_type) : list (DefinedFunction UnitAnn T) - := map (df_deriv df) lv. - - End deriv. - - Section eval. - - Program - Fixpoint vartlookup (l:df_env) (a:var_type) : - option (definition_function_types_interp (snd a)) - := match l with - | nil => None - | fv::os => if a == (projT1 fv) then - Some (eq_rect _ definition_function_types_interp (projT2 fv) _ _) - else vartlookup os a - end. - - Fixpoint vart_update (l:df_env) (a:var_type) (n:definition_function_types_interp (snd a)) : df_env - := match l with - | nil => (mk_env_entry a n)::nil - | fv::os => if a == (projT1 fv) then - (mk_env_entry a n)::os else fv::(vart_update os a n) - end. - - Fixpoint df_eval {T Ann} (σ:df_env) (df:DefinedFunction Ann T) : option (definition_function_types_interp T) - := match df with - | Number _ r => Some r - | Constant t _ x => Some x - | DVector n _ dfs => vectoro_to_ovector (vmap (fun x => df_eval σ x) dfs) - | DMatrix n m _ df => matrixo_to_omatrix (mmap (fun x => df_eval σ x) df) - | Var x _ => vartlookup σ x - | Plus _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (l' + r') - | _, _ => None - end - | Minus _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (l' - r') - | _, _ => None - end - | Times _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (l' * r') - | _, _ => None - end - | Divide _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (l' / r') - | _, _ => None - end - | Square _ e => - match df_eval σ e with - | Some v => Some (v * v) - | _ => None - end - | Exp _ e => - match df_eval σ e with - | Some v => Some (Fexp v) - | _ => None - end - | Log _ e => - match df_eval σ e with - | Some v => Some (Fln v) - | _ => None - end - | Abs _ e => - match df_eval σ e with - | Some v => Some (Fabs v) - | _ => None - end - | Sign _ e => - match df_eval σ e with - | Some v => Some (sign v) - | _ => None - end - | PSign _ e => - match df_eval σ e with - | Some v => Some (pos_sign v) - | _ => None - end - | Max _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (Fmax l' r') - | _, _ => None - end - | VectorElem n _ l i => - match (df_eval σ l) with - | Some l' => Some (vnth l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval σ l) with - | Some l' => Some (mnth l' i j) - | _ => None - end - | VectorDot n _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (vsum (vmap2 Fmult l' r')) - | _, _ => None - end - | VectorSum n _ l => - match df_eval σ l with - | Some l' => Some (vsum l') - | _ => None - end - | MatrixSum n m _ l => - match df_eval σ l with - | Some l' => Some (msum l') - | _ => None - end - | VectorScalMult n _ x r => - match df_eval σ x, df_eval σ r with - | Some x', Some r' => Some (vmap (fun rr => x' * rr) r') - | _, _ => None - end - | MatrixScalMult n m _ x r => - match df_eval σ x, df_eval σ r with - | Some x', Some r' => Some (mmap (fun rr => x' * rr) r') - | _, _ => None - end - | MatrixVectorMult n m _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (matrix_vector_mult l' r') - | _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (matrix_vector_add l' r') - | _, _ => None - end - | MatrixMult n m p _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (matrix_mult l' r') - | _, _ => None - end - | VectorPlus n _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (vmap2 Fplus l' r') - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (vmap2 Fminus l' r') - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (mmap2 Fplus l' r') - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval σ l, df_eval σ r with - | Some l', Some r' => Some (mmap2 Fminus l' r') - | _, _ => None - end - | VectorApply n _ x s r => - match df_eval σ r with - | Some r' => vectoro_to_ovector - (vmap (fun re => - let xv := (x, DTfloat):var_type in - df_eval (cons (mk_env_entry xv re) nil) s) r') - | _ => None - end - | MatrixApply n m _ x s r => - match df_eval σ r with - | Some r' => matrixo_to_omatrix - (mmap (fun re => - let xv := (x, DTfloat):var_type in - df_eval (cons (mk_env_entry xv re) nil) s) r') - | _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_eval σ l with - | Some l' => - match (vectoro_to_ovector - (build_vector (fun i => - let xv1 := (v1,DTfloat):var_type in - let xv2 := (v2,DTfloat):var_type in - df_eval (cons (mk_env_entry xv1 (vnth l' i)) - (cons (mk_env_entry xv2 (vnth r i)) nil)) s))) with - | Some vv => Some (vsum vv) - | _ => None - end - | _ => None - end - | MLossfun n m _ v1 v2 s l r => - match df_eval σ l with - | Some l' => - match (matrixo_to_omatrix - (build_matrix (fun i j => - let xv1 := (v1,DTfloat):var_type in - let xv2 := (v2,DTfloat):var_type in - df_eval (cons (mk_env_entry xv1 (mnth l' i j)) - (cons (mk_env_entry xv2 (mnth r i j)) nil)) s))) with - | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m))) - | _ => None - end - | _ => None - end - - end. - - Fixpoint df_eval_tree {T Ann} (σ:df_env) (df:DefinedFunction Ann T) : option (DefinedFunction EvalAnn T) - := match df with - | Number _ r => Some (Number r r) - | Constant t _ x => Some (Constant x x) - | DVector n _ dfs => - match vectoro_to_ovector (vmap (fun x => df_eval_tree σ x) dfs) with - | Some val => Some (DVector (vmap get_annotation val) val) - | _ => None - end - | DMatrix n m _ df => - match matrixo_to_omatrix (mmap (fun x => df_eval_tree σ x) df) with - | Some val => Some (DMatrix - (vmap (fun x => vmap get_annotation x) val) val) - | _ => None - end - | Var x _ => match vartlookup σ x with - | Some val => Some (Var x val) - | _ => None - end - | Plus _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => Some (Plus ((get_annotation l') + (get_annotation r')) - l' r') - | _, _ => None - end - | Minus _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => Some (Minus ((get_annotation l') - (get_annotation r')) - l' r') - | _, _ => None - end - | Times _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => Some (Times ((get_annotation l') * (get_annotation r')) - l' r') - | _, _ => None - end - | Divide _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => Some (Divide ((get_annotation l') / (get_annotation r')) - l' r') - | _, _ => None - end - | Square _ e => - match df_eval_tree σ e with - | Some vv => let v := get_annotation vv in Some (Square (v * v) vv) - | _ => None - end - | Exp _ e => - match df_eval_tree σ e with - | Some vv => Some (Exp (Fexp (get_annotation vv)) vv) - | _ => None - end - | Log _ e => - match df_eval_tree σ e with - | Some vv => Some (Log (Fln (get_annotation vv)) vv) - | _ => None - end - | Abs _ e => - match df_eval_tree σ e with - | Some vv => Some (Abs (Fabs (get_annotation vv)) vv) - | _ => None - end - | Sign _ e => - match df_eval_tree σ e with - | Some vv => Some (Sign (sign (get_annotation vv)) vv) - | _ => None - end - | PSign _ e => - match df_eval_tree σ e with - | Some vv => Some (PSign (pos_sign (get_annotation vv)) vv) - | _ => None - end - | Max _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => Some (Max (Fmax (get_annotation l') (get_annotation r')) - l' r') - | _, _ => None - end - | VectorElem n _ l i => - match (df_eval_tree σ l) with - | Some l' => let vl' := get_annotation l' in - Some (VectorElem (vnth vl' i) l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_tree σ l) with - | Some l' => let vl' := get_annotation l' in - Some (MatrixElem (mnth vl' i j) l' i j) - | _ => None - end - | VectorDot n _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - Some (VectorDot (vsum (build_vector (fun i => (vnth vl' i) * (vnth vr' i)))) l' r') - | _, _ => None - end - | VectorSum n _ l => - match df_eval_tree σ l with - | Some l' => let vl' := get_annotation l' in - Some (VectorSum (vsum vl') l') - | _ => None - end - | MatrixSum n m _ l => - match df_eval_tree σ l with - | Some l' => let vl' := get_annotation l' in - Some (MatrixSum (msum vl') l') - | _ => None - end - | VectorScalMult n _ x r => - match df_eval_tree σ x, df_eval_tree σ r with - | Some x', Some r' => let vx' := get_annotation x' in - let vr' := get_annotation r' in - let vec : vector float n := vmap (fun re => vx' * re) vr' in - Some (VectorScalMult vec x' r') - | _, _ => None - end - | MatrixScalMult n m _ x r => - match df_eval_tree σ x, df_eval_tree σ r with - | Some x', Some r' => let vx' := get_annotation x' in - let vr' := get_annotation r' in - let mat : matrix float n m := mmap (fun re => vx' * re) vr' in - Some (MatrixScalMult mat x' r') - | _, _ => None - end - | MatrixVectorMult n m _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - Some (MatrixVectorMult (matrix_vector_mult vl' vr') l' r') - | _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - Some (MatrixVectorAdd (matrix_vector_add vl' vr') l' r') - | _, _ => None - end - | MatrixMult n m p _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - Some (MatrixMult (matrix_mult vl' vr') l' r') - | _, _ => None - end - | VectorPlus n _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - let vec : vector float n := - vmap2 Fplus vl' vr' in - Some (VectorPlus vec l' r') - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - let vec : vector float n := - vmap2 Fminus vl' vr' in - Some (VectorMinus vec l' r') - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - let mat : matrix float n m := - mmap2 Fplus vl' vr' in - Some (MatrixPlus mat l' r') - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_tree σ l, df_eval_tree σ r with - | Some l', Some r' => let vl' := get_annotation l' in - let vr' := get_annotation r' in - let mat : matrix float n m := - mmap2 Fminus vl' vr' in - Some (MatrixMinus mat l' r') - | _, _ => None - end - - | VectorApply n _ x s r => - match df_eval_tree σ r with - | Some r' => - let vr' := get_annotation r' in - match vectoro_to_ovector - (build_vector (fun i => - let xv := (x, DTfloat):var_type in - df_eval (cons (mk_env_entry xv (vnth vr' i)) nil) s)) with - | Some val => Some (VectorApply val x s r') - | _ => None - end - | _ => None - end - | MatrixApply n m _ x s r => - match df_eval_tree σ r with - | Some r' => - let vr' := get_annotation r' in - match matrixo_to_omatrix - (build_matrix (fun i j => - let xv := (x, DTfloat):var_type in - df_eval (cons (mk_env_entry xv (mnth vr' i j)) nil) s)) with - | Some val => Some (MatrixApply val x s r') - | _ => None - end - | _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_eval_tree σ l with - | Some l' => - let vl' := get_annotation l' in - match (vectoro_to_ovector - (build_vector (fun i => - let xv1 := (v1,DTfloat):var_type in - let xv2 := (v2,DTfloat):var_type in - df_eval (cons (mk_env_entry xv1 (vnth vl' i)) - (cons (mk_env_entry xv2 (vnth r i)) nil)) s))) with - | Some vv => Some (VLossfun (vsum vv) v1 v2 s l' r) - | _ => None - end - | _ => None - end - | MLossfun n m _ v1 v2 s l r => - match df_eval_tree σ l with - | Some l' => - let vl' := get_annotation l' in - match (matrixo_to_omatrix - (build_matrix (fun i j => - let xv1 := (v1,DTfloat):var_type in - let xv2 := (v2,DTfloat):var_type in - df_eval (cons (mk_env_entry xv1 (mnth vl' i j)) - (cons (mk_env_entry xv2 (mnth r i j)) nil)) s))) with - | Some vv => Some (MLossfun ((msum vv)/(FfromZ (Z.of_nat m))) v1 v2 s l' r) - | _ => None - end - | _ => None - end - end. - - Definition eval_env_entry_type := {T:definition_function_types & (DefinedFunction UnitAnn T) & definition_function_types_interp T}. - Definition df_eval_env := list eval_env_entry_type. - - Definition mk_eval_env_entry {T} df val : eval_env_entry_type - := let P := fun t => DefinedFunction UnitAnn t in - let Q := fun t => definition_function_types_interp t in - existT2 P Q T df val. - - Definition pair_update_evals {T} (df:DefinedFunction UnitAnn T) (val:definition_function_types_interp T) (dfevals : df_eval_env) : (definition_function_types_interp T * df_eval_env) := - (val, (mk_eval_env_entry df val)::dfevals). - - Fixpoint df_evals_list {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (dfevals : df_eval_env) : option (definition_function_types_interp T * df_eval_env) - := match df with - | Number _ r => Some (pair_update_evals (Number tt r) r dfevals) - | Constant t _ x => Some (pair_update_evals (Constant tt x) x dfevals) - | DVector n _ dfs => None (*vectoro_to_ovector (fun i => df_eval σ (dfs i))*) - | DMatrix n m _ df => None (*matrixo_to_omatrix (fun i j => df_eval σ (df i j))*) - | Var x _ => - match vartlookup σ x with - | Some val => Some (pair_update_evals (Var x tt) val dfevals) - | _ => None - end - | Plus _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => Some (pair_update_evals (Plus tt l r) (l'+r') dfevals'') - | _ => None - end - | _ => None - end - | Minus _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => Some (pair_update_evals (Minus tt l r) (l'-r') dfevals'') - | _ => None - end - | _ => None - end - | Times _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => Some (pair_update_evals (Times tt l r) (l'*r') dfevals'') - | _ => None - end - | _ => None - end - | Divide _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => Some (pair_update_evals (Divide tt l r) (l'/r') dfevals'') - | _ => None - end - | _ => None - end - | Square _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (Square tt e) (v * v) dfevals') - | _ => None - end - | Exp _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (Exp tt e) (Fexp v) dfevals') - | _ => None - end - | Log _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (Log tt e) (Fln v) dfevals') - | _ => None - end - | Abs _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (Abs tt e) (Fabs v) dfevals') - | _ => None - end - | Sign _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (Sign tt e) (sign v) dfevals') - | _ => None - end - | PSign _ e => - match df_evals_list σ e dfevals with - | Some (v, dfevals') => Some (pair_update_evals (PSign tt e) (pos_sign v) dfevals') - | _ => None - end - | Max _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => Some (pair_update_evals (Max tt l r) (Fmax l' r') dfevals'') - | _ => None - end - | _ => None - end - | VectorElem n _ l i => - match (df_evals_list σ l dfevals) with - | Some (l', dfevals') => Some (pair_update_evals (VectorElem tt l i) (vnth l' i) dfevals') - | _ => None - end - | MatrixElem m n _ l i j => - match (df_evals_list σ l dfevals) with - | Some (l', dfevals') => Some (pair_update_evals (MatrixElem tt l i j) (mnth l' i j) dfevals') - | _ => None - end - | VectorDot n _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (VectorDot tt l r) - (vsum (vmap2 Fmult l' r')) dfevals'') - | _ => None - end - | _ => None - end - | VectorSum n _ l => - match df_evals_list σ l dfevals with - | Some (l',dfevals') => Some (pair_update_evals (VectorSum tt l) (vsum l') dfevals') - | _ => None - end - | MatrixSum n m _ l => - match df_evals_list σ l dfevals with - | Some (l',dfevals') => Some (pair_update_evals (MatrixSum tt l) (msum l') dfevals') - | _ => None - end - | VectorScalMult n _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (VectorScalMult tt l r) - (vmap (fun re => l' * re) r') dfevals'') - | _ => None - end - | _ => None - end - | MatrixScalMult n m _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixScalMult tt l r) - (mmap (fun re => l' * re) r') dfevals'') - | _ => None - end - | _ => None - end - | MatrixVectorMult n m _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixVectorMult tt l r) - (matrix_vector_mult l' r') dfevals'') - | _ => None - end - | _ => None - end - | MatrixVectorAdd n m _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixVectorAdd tt l r) - (matrix_vector_add l' r') dfevals'') - | _ => None - end - | _ => None - end - | MatrixMult n m p _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixMult tt l r) (matrix_mult l' r') dfevals'') - | _ => None - end - | _ => None - end - | VectorPlus n _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (VectorPlus tt l r) - (vmap2 Fplus l' r') dfevals'') - | _ => None - end - | _ => None - end - | VectorMinus n _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (VectorMinus tt l r) - (vmap2 Fminus l' r') dfevals'') - | _ => None - end - | _ => None - end - | MatrixPlus n m _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixPlus tt l r) - (mmap2 Fplus l' r') dfevals'') - | _ => None - end - | _ => None - end - | MatrixMinus n m _ l r => - match df_evals_list σ l dfevals with - | Some (l', dfevals') => - match df_evals_list σ r dfevals' with - Some (r', dfevals'') => - Some (pair_update_evals (MatrixMinus tt l r) - (mmap2 Fminus l' r') dfevals'') - | _ => None - end - | _ => None - end - | VectorApply n _ x s r => - match df_evals_list σ r dfevals with -(* | Some r' => vectoro_to_ovector - (fun i => - let xv := (x, DTfloat):var_type in - df_eval (cons (mk_env_entry xv (r' i)) σ) s) *) - | _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_evals_list σ l dfevals with -(* | Some l' => - match (vectoro_to_ovector - (fun i => - let xv1 := (v1,DTfloat):var_type in - let xv2 := (v2,DTfloat):var_type in - df_eval (cons (mk_env_entry xv1 (l' i)) - (cons (mk_env_entry xv2 (r i)) σ)) s)) with - | Some vv => Some (vsum vv) - | _ => None - end *) - | _ => None - end - | _ => None - end. - -(* - Program - Fixpoint evalslookup {T} (l:df_eval_env) (df:DefinedFunction UnitAnn T) : - option (definition_function_types_interp T) - := match l with - | nil => None - | fv::os => if T == (projT1 (sigT_of_sigT2 fv)) then - if df == (projT2 (sigT_of_sigT2 fv)) then - Some (eq_rect _ definition_function_types_interp (projT3 fv) _ _) - else evalslookup os df - else evalslookup os df - end. -*) - Definition df_eval_symbolic_gradient {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (lv:list var_type) : option (list (definition_function_types_interp T)) - := listo_to_olist (map (df_eval σ) (df_gradient df lv)). - - End eval. - - Section isderiv. - - Context (σ:df_env). - Context (v:SubVar). -(* - Inductive is_deriv : DefinedFunction -> float -> Prop - := - | is_deriv_Number (x : float) : is_deriv (Number x) 0 - | is_deriv_Var_eq : is_deriv (Var v) 1 - | is_deriv_Var_neq (sv : SubVar) : sv <> v -> is_deriv (Var sv) 0 - | is_deriv_Plus l l' r r' : - is_deriv l l' -> - is_deriv r r' -> - is_deriv (Plus l r) (l' + r') - | is_deriv_Minus l l' r r' : - is_deriv l l' -> - is_deriv r r' -> - is_deriv (Minus l r) (l' - r') - | is_deriv_Times l le l' r re r' : - df_eval σ l = Some le -> - is_deriv l l' -> - df_eval σ r = Some re -> - is_deriv r r' -> - is_deriv (Times l r) ((le * r') + (l' * re)) - | is_deriv_Divide l le l' r re r' : - df_eval σ l = Some le -> - is_deriv l l' -> - df_eval σ r = Some re -> - is_deriv r r' -> - is_deriv (Times l r) - (((l' * re ) - (le * r')) - / (re * re)) - | is_deriv_Exp e ee e' : - df_eval σ e = Some ee -> - is_deriv e e' -> - is_deriv (Exp e) (e' * (Fexp ee)) - | is_deriv_Log e ee e' : - df_eval σ e = Some ee -> - is_deriv e e' -> - is_deriv (Exp e) (e' / ee) - | is_deriv_Abs e ee e' : - df_eval σ e = Some ee -> - is_deriv e e' -> is_deriv (Abs e) (e' * (sign ee)) - | is_deriv_Sign (e : DefinedFunction) : - is_deriv (Sign e) 0 - | is_deriv_PSign (e : DefinedFunction) : - is_deriv (PSign e) 0 - | is_deriv_Max_l l le l' re r : - df_eval σ l = Some le -> - df_eval σ r = Some re -> - (le > re) = true -> - is_deriv l l' -> - is_deriv (Max l r) l' - | is_deriv_Max_r l le r re r' : - df_eval σ l = Some le -> - df_eval σ r = Some re -> - (re >= le) = true -> - is_deriv r r' -> - is_deriv (Max l r) r'. - (* - | is_deriv_Max_eq l l' ee r r' : - df_eval σ l = Some ee -> - df_eval σ r = Some ee -> - is_deriv l l' -> - is_deriv r r' -> - is_deriv (Max l r) ((l' + r')/2) *) - -*) - End isderiv. - - Section deriv2. - - Fixpoint df_eval_deriv {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v:var_type) : option (definition_function_types_interp T) - := (match df with - | Number _ _ => Some 0 - | Constant t _ x => Some - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ dfs => vectoro_to_ovector (vmap (fun xe => df_eval_deriv σ xe v) dfs) - | DMatrix n m _ df => matrixo_to_omatrix (mmap (fun xe => df_eval_deriv σ xe v) df) - | Var x _ => Some (let t:=snd x in - match t return definition_function_types_interp t with - | DTfloat => if x == v then 1 else 0 - | DTVector n => ConstVector n (if x == v then 1 else 0) - | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0) - end) - | Plus _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some le, Some lr => Some (le + lr) - | _, _ => None - end - | Minus _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some le, Some lr => Some (le - lr) - | _, _ => None - end - | Times _ l r => - match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with - | Some le, Some ld, Some re, Some rd => - Some (le * rd + - (ld * re)) - | _, _, _, _ => None - end - | Divide _ l r => - match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with - | Some le, Some ld, Some re, Some rd => - Some ((ld / re) - ((le * rd) / (re * re))) - | _, _, _, _ => None - end - | Square _ e => - match df_eval σ e, df_eval_deriv σ e v with - | Some ee, Some ed => Some (2 * ee * ed) - | _, _ => None - end - | Exp _ e => - match df_eval σ e, df_eval_deriv σ e v with - | Some ee, Some ed => Some (ed * Fexp ee) - | _, _ => None - end - | Log _ e => - match df_eval σ e, df_eval_deriv σ e v with - | Some ee, Some ed => Some (ed / ee) - | _, _ => None - end - | Abs _ e => - match df_eval σ e, df_eval_deriv σ e v with - | Some ee, Some ed => Some (ed * (sign ee)) - | _, _ => None - end - | Sign _ e => - match df_eval_deriv σ e v with - | Some _ => Some 0 - | None => None - end - | PSign _ e => - match df_eval_deriv σ e v with - | Some _ => Some 0 - | None => None - end - | Max _ l r => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - if le <= re then df_eval_deriv σ r v else df_eval_deriv σ l v - | _, _ => None - end - | VectorElem n _ l i => - match (df_eval_deriv σ l v) with - | Some l' => Some (vnth l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_deriv σ l v) with - | Some l' => Some (mnth l' i j) - | _ => None - end - | VectorDot n _ l r => - match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with - | Some le, Some ld, Some re, Some rd => - Some (Fplus (vsum (vmap2 Fmult le rd)) (vsum (vmap2 Fmult ld re))) - | _, _, _, _ => None - end - | VectorSum n _ l => - match df_eval_deriv σ l v with - | Some ld => - Some (vsum ld) - | _ => None - end - | MatrixSum n m _ l => - match df_eval_deriv σ l v with - | Some ld => - Some (msum ld) - | _ => None - end - | VectorScalMult n _ x r => - match df_eval σ x, df_eval_deriv σ x v, df_eval σ r, df_eval_deriv σ r v with - | Some xe, Some xd, Some re, Some rd => - Some (vmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re) - | _, _, _, _ => None - end - | MatrixScalMult n m _ x r => - match df_eval σ x, df_eval_deriv σ x v, df_eval σ r, df_eval_deriv σ r v with - | Some xe, Some xd, Some re, Some rd => - Some (mmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re) - | _, _, _, _ => None - end - | MatrixVectorMult n m _ l r => - match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with - | Some le, Some ld, Some re, Some rd => - Some (build_vector (fun i => vsum (vmap4 (fun lei rde ldi ree => lei * rde + ldi * ree) - (vnth le i) rd (vnth ld i) re))) - | _, _, _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some ld, Some rd => - Some (build_matrix (fun i j => (mnth ld i j) + (vnth rd i))) - | _, _ => None - end - | MatrixMult n m p _ l r => - match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with - | Some le, Some ld, Some re, Some rd => - Some (build_matrix (fun i k => vsum (build_vector (fun j => (mnth le i j)*(mnth rd j k) + (mnth ld i j)*(mnth re j k))))) - | _, _, _, _ => None - end - | VectorPlus n _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some l', Some r' => Some (vmap2 Fplus l' r') - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some l', Some r' => Some (vmap2 Fminus l' r') - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some l', Some r' => Some (mmap2 Fplus l' r') - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_deriv σ l v, df_eval_deriv σ r v with - | Some l', Some r' => Some (mmap2 Fminus l' r') - | _, _ => None - end - | VectorApply n _ x s r => - match df_eval σ r, df_eval_deriv σ r v with - | Some re, Some rd => - vectoro_to_ovector - (build_vector (fun i => - let xv := (x, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv (vnth re i)) nil) s xv with - | Some sd => Some ((vnth rd i) * sd) - | _ => None - end)) - | _, _ => None - end - | MatrixApply n m _ x s r => - match df_eval σ r, df_eval_deriv σ r v with - | Some re, Some rd => - matrixo_to_omatrix - (build_matrix (fun i j => - let xv := (x, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv (mnth re i j)) nil) s xv with - | Some sd => Some ((mnth rd i j) * sd) - | _ => None - end)) - | _, _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_eval σ l, df_eval_deriv σ l v with - | Some le, Some ld => - match (vectoro_to_ovector - (build_vector (fun i => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv1 (vnth le i)) - (cons (mk_env_entry xv2 (vnth r i)) nil)) s xv1 with - | Some sd => Some ((vnth ld i) * sd) - | _ => None - end))) with - | Some vv => Some (vsum vv) - | _ => None - end - | _, _ => None - end - | MLossfun n m _ v1 v2 s l r => - match df_eval σ l, df_eval_deriv σ l v with - | Some le, Some ld => - match (matrixo_to_omatrix - (build_matrix (fun i j => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv1 (mnth le i j)) - (cons (mk_env_entry xv2 (mnth r i j)) nil)) s xv1 with - | Some sd => Some ((mnth ld i j) * sd) - | _ => None - end))) with - | Some vv => Some ((msum vv)/(FfromZ (Z.of_nat m))) - | _ => None - end - | _, _ => None - end - end). - - Definition mk_genvar_env (s:SubVar) := mk_env_entry (s, DTfloat) (FfromZ (Z.of_nat 1)) :: nil. - - (* the v environment below pairs variables with their derivatives *) - (* in some sense this is giving a directional derivative defined by v *) - Fixpoint df_eval_deriv_genvar {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v:df_env) : option (definition_function_types_interp T) - := (match df with - | Number _ _ => Some 0 - | Constant t _ x => Some - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ dfs => vectoro_to_ovector (vmap (fun df1 => df_eval_deriv_genvar σ df1 v) dfs) - | DMatrix n m _ df => matrixo_to_omatrix (mmap (fun df1 => df_eval_deriv_genvar σ df1 v) df) - | Var x _ => Some ( - match vartlookup v x with - | Some val => val - | _ => - match (snd x) with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - end) - | Plus _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some le, Some lr => Some (le + lr) - | _, _ => None - end - | Minus _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some le, Some lr => Some (le - lr) - | _, _ => None - end - | Times _ l r => - match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some le, Some ld, Some re, Some rd => - Some (le * rd + - (ld * re)) - | _, _, _, _ => None - end - | Divide _ l r => - match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some le, Some ld, Some re, Some rd => - Some ((ld / re) - ((le * rd) / (re * re))) - | _, _, _, _ => None - end - | Square _ e => - match df_eval σ e, df_eval_deriv_genvar σ e v with - | Some ee, Some ed => Some (2 * ee * ed) - | _, _ => None - end - | Exp _ e => - match df_eval σ e, df_eval_deriv_genvar σ e v with - | Some ee, Some ed => Some (ed * Fexp ee) - | _, _ => None - end - | Log _ e => - match df_eval σ e, df_eval_deriv_genvar σ e v with - | Some ee, Some ed => Some (ed / ee) - | _, _ => None - end - | Abs _ e => - match df_eval σ e, df_eval_deriv_genvar σ e v with - | Some ee, Some ed => Some (ed * (sign ee)) - | _, _ => None - end - | Sign _ e => - match df_eval_deriv_genvar σ e v with - | Some _ => Some 0 - | None => None - end - | PSign _ e => - match df_eval_deriv_genvar σ e v with - | Some _ => Some 0 - | None => None - end - | Max _ l r => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - if le <= re then df_eval_deriv_genvar σ r v else df_eval_deriv_genvar σ l v - | _, _ => None - end - | VectorElem n _ l i => - match (df_eval_deriv_genvar σ l v) with - | Some l' => Some (vnth l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_deriv_genvar σ l v) with - | Some l' => Some (mnth l' i j) - | _ => None - end - | VectorDot n _ l r => - match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some le, Some ld, Some re, Some rd => - Some (vsum (vmap4 (fun le1 rd1 ld1 re1 => le1 * rd1 + ld1 * re1) le rd ld re)) - | _, _, _, _ => None - end - | VectorSum n _ l => - match df_eval_deriv_genvar σ l v with - | Some ld => - Some (vsum ld) - | _ => None - end - | MatrixSum n m _ l => - match df_eval_deriv_genvar σ l v with - | Some ld => - Some (msum ld) - | _ => None - end - | VectorScalMult n _ x r => - match df_eval σ x, df_eval_deriv_genvar σ x v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some xe, Some xd, Some re, Some rd => - Some (vmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re) - | _, _, _, _ => None - end - | MatrixScalMult n m _ x r => - match df_eval σ x, df_eval_deriv_genvar σ x v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some xe, Some xd, Some re, Some rd => - Some (mmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re) - | _, _, _, _ => None - end - | MatrixVectorMult n m _ l r => - match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some le, Some ld, Some re, Some rd => - Some (build_vector (fun i => vsum (vmap4 (fun lei rde ldi ree => lei * rde + ldi * ree) - (vnth le i) rd (vnth ld i) re))) - | _, _, _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some ld, Some rd => - Some (build_matrix (fun i j => (mnth ld i j) + (vnth rd i))) - | _, _ => None - end - | MatrixMult n m p _ l r => - match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with - | Some le, Some ld, Some re, Some rd => - Some (build_matrix (fun i k => vsum (build_vector (fun j => (mnth le i j)*(mnth rd j k) + (mnth ld i j)*(mnth re j k))))) - | _, _, _, _ => None - end - | VectorPlus n _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some l', Some r' => Some (vmap2 Fplus l' r') - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some l', Some r' => Some (vmap2 Fminus l' r') - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some l', Some r' => Some (mmap2 Fplus l' r') - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with - | Some l', Some r' => Some (mmap2 Fminus l' r') - | _, _ => None - end - | VectorApply n _ x s r => - match df_eval σ r, df_eval_deriv_genvar σ r v with - | Some re, Some rd => - vectoro_to_ovector - (build_vector (fun i => - let xv := (x, DTfloat):var_type in - match df_eval_deriv_genvar (mk_env_entry xv (vnth re i) :: nil) s - (mk_genvar_env x) with - | Some sd => Some ((vnth rd i) * sd) - | _ => None - end)) - | _, _ => None - end - | MatrixApply n m _ x s r => - match df_eval σ r, df_eval_deriv_genvar σ r v with - | Some re, Some rd => - matrixo_to_omatrix - (build_matrix (fun i j => - let xv := (x, DTfloat):var_type in - match df_eval_deriv_genvar (mk_env_entry xv (mnth re i j) :: nil) s - (mk_genvar_env x) with - | Some sd => Some ((mnth rd i j) * sd) - | _ => None - end)) - | _, _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_eval σ l, df_eval_deriv_genvar σ l v with - | Some le, Some ld => - match (vectoro_to_ovector - (build_vector (fun i => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv_genvar ( mk_env_entry xv1 (vnth le i) :: - mk_env_entry xv2 (vnth r i) :: nil) s - (mk_genvar_env v1) with - | Some sd => Some ((vnth ld i) * sd) - | _ => None - end))) with - | Some vv => Some (vsum vv) - | _ => None - end - | _, _ => None - end - | MLossfun n m _ v1 v2 s l r => - match df_eval σ l, df_eval_deriv_genvar σ l v with - | Some le, Some ld => - match (matrixo_to_omatrix - (build_matrix (fun i j => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv_genvar ( mk_env_entry xv1 (mnth le i j) :: - mk_env_entry xv2 (mnth r i j) :: nil) s - ( mk_genvar_env v1) with - | Some sd => Some ((mnth ld i j) * sd) - | _ => None - end))) with - | Some vv => Some ((msum vv)/(FfromZ (Z.of_nat m))) - | _ => None - end - | _, _ => None - end - end). - - - Definition definition_function_types_interp_prod (vart dft:definition_function_types) : Type - := match vart with - | DTfloat => definition_function_types_interp dft - | DTVector n => Vector (definition_function_types_interp dft) n - | DTMatrix m n => Matrix (definition_function_types_interp dft) m n - end. - - - Definition UnitVector (n:nat) (j : {n':nat | (n' < n)%nat}) : vector float n := - build_vector (fun i => if (proj1_sig i) == (proj1_sig j) then 1 else 0). - - Definition UnitMatrix (n m: nat) - (i : {n':nat | (n' < n)%nat}) - (j : {m':nat | (m' < m)%nat}) : matrix float n m := - build_matrix (fun a b => if (proj1_sig a) == (proj1_sig i) then - (if (proj1_sig b) == (proj1_sig j) then 1 else 0) - else 0). - - Definition const_env (v : var_type) : df_env - := match (snd v) with - | DTfloat => ((mk_env_entry (fst v, DTfloat) 0)::nil) - | DTVector n => ((mk_env_entry (fst v, DTVector n) (ConstVector n 0))::nil) - | DTMatrix m n => ((mk_env_entry (fst v, DTMatrix m n) (ConstMatrix m n 0))::nil) - end. - - Fixpoint df_eval_tree_deriv {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (v:var_type) : option (definition_function_types_interp T) - := (match df with - | Number _ _ => Some 0 - | Constant t _ x => Some - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ dfs => vectoro_to_ovector (vmap (fun e => df_eval_tree_deriv σ e v) dfs) - | DMatrix n m _ df => matrixo_to_omatrix (mmap (fun e => df_eval_tree_deriv σ e v) df) - | Var x _ => Some (let t:=snd x in - match t return definition_function_types_interp t with - | DTfloat => if x == v then 1 else 0 - | DTVector n => ConstVector n (if x == v then 1 else 0) - | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0) - end) - | Plus _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some le, Some lr => Some (le + lr) - | _, _ => None - end - | Minus _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some le, Some lr => Some (le - lr) - | _, _ => None - end - | Times _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some (le * rd + - (ld * re)) - | _, _ => None - end - | Divide _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some ((ld / re) - ((le * rd) / (re * re))) - | _, _ => None - end - | Square _ e => - let ee := get_annotation e in - match df_eval_tree_deriv σ e v with - | Some ed => Some (2 * ee * ed) - | _ => None - end - | Exp _ e => - let ee := get_annotation e in - match df_eval_tree_deriv σ e v with - | Some ed => Some (ed * Fexp ee) - | _ => None - end - | Log _ e => - let ee := get_annotation e in - match df_eval_tree_deriv σ e v with - | Some ed => Some (ed / ee) - | _ => None - end - | Abs _ e => - let ee := get_annotation e in - match df_eval_tree_deriv σ e v with - | Some ed => Some (ed * (sign ee)) - | _ => None - end - | Sign _ e => - match df_eval_tree_deriv σ e v with - | Some _ => Some 0 - | None => None - end - | PSign _ e => - match df_eval_tree_deriv σ e v with - | Some _ => Some 0 - | None => None - end - | Max _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - if le <= re then df_eval_tree_deriv σ r v else df_eval_tree_deriv σ l v - | VectorElem n _ l i => - match (df_eval_tree_deriv σ l v) with - | Some l' => Some (vnth l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_tree_deriv σ l v) with - | Some l' => Some (mnth l' i j) - | _ => None - end - | VectorDot n _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some (Fplus (vsum (vmap2 Fmult le rd)) (vsum (vmap2 Fmult ld re))) - | _, _ => None - end - | VectorSum n _ l => - match df_eval_tree_deriv σ l v with - | Some ld => - Some (vsum ld) - | _ => None - end - | MatrixSum n m _ l => - match df_eval_tree_deriv σ l v with - | Some ld => - Some (msum ld) - | _ => None - end - | VectorScalMult n _ x r => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_deriv σ x v, df_eval_tree_deriv σ r v with - | Some xd, Some rd => - Some (vmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re) - | _, _ => None - end - | MatrixScalMult n m _ x r => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_deriv σ x v, df_eval_tree_deriv σ r v with - | Some xd, Some rd => - Some (mmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re) - | _, _ => None - end - | MatrixVectorMult n m _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some (build_vector (fun i => vsum (vmap4 (fun lei rde ldi ree => lei * rde + ldi * ree) - (vnth le i) rd (vnth ld i) re))) - | _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some (build_matrix (fun i j => (mnth ld i j) + (vnth rd i))) - | _, _ => None - end - | MatrixMult n m p _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some ld, Some rd => - Some (build_matrix (fun i k => vsum (build_vector (fun j => (mnth le i j)*(mnth rd j k) + (mnth ld i j)*(mnth re j k))))) - | _, _ => None - end - | VectorPlus n _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some l', Some r' => Some (vmap2 Fplus l' r') - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some l', Some r' => Some (vmap2 Fminus l' r') - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some l', Some r' => Some (mmap2 Fplus l' r') - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with - | Some l', Some r' => Some (mmap2 Fminus l' r') - | _, _ => None - end - | VectorApply n _ x s r => - let re := get_annotation r in - match df_eval_tree_deriv σ r v with - | Some rd => - vectoro_to_ovector - (build_vector (fun i => - let xv := (x, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv (vnth re i)) nil ) s xv with - | Some sd => Some ((vnth rd i) * sd) - | _ => None - end)) - | _ => None - end - | MatrixApply n m _ x s r => - let re := get_annotation r in - match df_eval_tree_deriv σ r v with - | Some rd => - matrixo_to_omatrix - (build_matrix (fun i j => - let xv := (x, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv (mnth re i j)) nil ) s xv with - | Some sd => Some ((mnth rd i j) * sd) - | _ => None - end)) - | _ => None - end - | VLossfun n _ v1 v2 s l r => - let le := get_annotation l in - match df_eval_tree_deriv σ l v with - | Some ld => - match (vectoro_to_ovector - (build_vector (fun i => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv1 (vnth le i)) - (cons (mk_env_entry xv2 (vnth r i)) nil)) s xv1 with - | Some sd => Some ((vnth ld i) * sd) - | _ => None - end))) with - | Some vv => Some (vsum vv) - | _ => None - end - | _ => None - end - | MLossfun n m _ v1 v2 s l r => - let le := get_annotation l in - match df_eval_tree_deriv σ l v with - | Some ld => - match (matrixo_to_omatrix - (build_matrix (fun i j => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv (cons (mk_env_entry xv1 (mnth le i j)) - (cons (mk_env_entry xv2 (mnth r i j)) nil )) s xv1 with - | Some sd => Some ((mnth ld i j) * sd) - | _ => None - end))) with - | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m))) - | _ => None - end - | _ => None - end - end). - - Fixpoint df_eval_tree_deriv_genvar {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (v:df_env) : option (definition_function_types_interp T) - := (match df with - | Number _ _ => Some 0 - | Constant t _ x => Some - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ dfs => vectoro_to_ovector (vmap (fun e => df_eval_tree_deriv_genvar σ e v) dfs) - | DMatrix n m _ df => matrixo_to_omatrix (mmap (fun e => df_eval_tree_deriv_genvar σ e v) df) - | Var x _ => Some ( - match vartlookup v x with - | Some val => val - | _ => - match (snd x) with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - end) - | Plus _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some le, Some lr => Some (le + lr) - | _, _ => None - end - | Minus _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some le, Some lr => Some (le - lr) - | _, _ => None - end - | Times _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some (le * rd + - (ld * re)) - | _, _ => None - end - | Divide _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some ((ld / re) - ((le * rd) / (re * re))) - | _, _ => None - end - | Square _ e => - let ee := get_annotation e in - match df_eval_tree_deriv_genvar σ e v with - | Some ed => Some (2 * ee * ed) - | _ => None - end - | Exp _ e => - let ee := get_annotation e in - match df_eval_tree_deriv_genvar σ e v with - | Some ed => Some (ed * Fexp ee) - | _ => None - end - | Log _ e => - let ee := get_annotation e in - match df_eval_tree_deriv_genvar σ e v with - | Some ed => Some (ed / ee) - | _ => None - end - | Abs _ e => - let ee := get_annotation e in - match df_eval_tree_deriv_genvar σ e v with - | Some ed => Some (ed * (sign ee)) - | _ => None - end - | Sign _ e => - match df_eval_tree_deriv_genvar σ e v with - | Some _ => Some 0 - | None => None - end - | PSign _ e => - match df_eval_tree_deriv_genvar σ e v with - | Some _ => Some 0 - | None => None - end - | Max _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - if le <= re then df_eval_tree_deriv_genvar σ r v else df_eval_tree_deriv_genvar σ l v - | VectorElem n _ l i => - match (df_eval_tree_deriv_genvar σ l v) with - | Some l' => Some (vnth l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_tree_deriv_genvar σ l v) with - | Some l' => Some (mnth l' i j) - | _ => None - end - | VectorDot n _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some (vsum (vmap4 (fun le1 rd1 ld1 re1 => le1 * rd1 + ld1 * re1) le rd ld re)) - | _, _ => None - end - | VectorSum n _ l => - match df_eval_tree_deriv_genvar σ l v with - | Some ld => - Some (vsum ld) - | _ => None - end - | MatrixSum n m _ l => - match df_eval_tree_deriv_genvar σ l v with - | Some ld => - Some (msum ld) - | _ => None - end - | VectorScalMult n _ x r => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_deriv_genvar σ x v, df_eval_tree_deriv_genvar σ r v with - | Some xd, Some rd => - Some (vmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re) - | _, _ => None - end - | MatrixScalMult n m _ x r => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_deriv_genvar σ x v, df_eval_tree_deriv_genvar σ r v with - | Some xd, Some rd => - Some (mmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re) - | _, _ => None - end - | MatrixVectorMult n m _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some (build_vector (fun i => vsum (vmap4 (fun lei rde ldi ree => lei * rde + ldi * ree) - (vnth le i) rd (vnth ld i) re))) - | _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some (build_matrix (fun i j => (mnth ld i j) + (vnth rd i))) - | _, _ => None - end - | MatrixMult n m p _ l r => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some ld, Some rd => - Some (build_matrix (fun i k => vsum (build_vector (fun j => (mnth le i j)*(mnth rd j k) + (mnth ld i j)*(mnth re j k))))) - | _, _ => None - end - | VectorPlus n _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some l', Some r' => Some (vmap2 Fplus l' r') - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some l', Some r' => Some (vmap2 Fminus l' r') - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some l', Some r' => Some (mmap2 Fplus l' r') - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with - | Some l', Some r' => Some (mmap2 Fminus l' r') - | _, _ => None - end - | VectorApply n _ x s r => - let re := get_annotation r in - match df_eval_tree_deriv_genvar σ r v with - | Some rd => - vectoro_to_ovector - (build_vector (fun i => - let xv := (x, DTfloat):var_type in - match df_eval_deriv_genvar (cons (mk_env_entry xv (vnth re i)) nil ) s - (mk_genvar_env x) with - | Some sd => Some ((vnth rd i) * sd) - | _ => None - end)) - | _ => None - end - | MatrixApply n m _ x s r => - let re := get_annotation r in - match df_eval_tree_deriv_genvar σ r v with - | Some rd => - matrixo_to_omatrix - (build_matrix (fun i j => - let xv := (x, DTfloat):var_type in - match df_eval_deriv_genvar (cons (mk_env_entry xv (mnth re i j)) nil) s - (mk_genvar_env x) with - | Some sd => Some ((mnth rd i j) * sd) - | _ => None - end)) - | _ => None - end - | VLossfun n _ v1 v2 s l r => - let le := get_annotation l in - match df_eval_tree_deriv_genvar σ l v with - | Some ld => - match (vectoro_to_ovector - (build_vector (fun i => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv_genvar (cons (mk_env_entry xv1 (vnth le i)) - (cons (mk_env_entry xv2 (vnth r i)) nil )) s - (mk_genvar_env v1) with - | Some sd => Some ((vnth ld i) * sd) - | _ => None - end))) with - | Some vv => Some (vsum vv) - | _ => None - end - | _ => None - end - | MLossfun n m _ v1 v2 s l r => - let le := get_annotation l in - match df_eval_tree_deriv_genvar σ l v with - | Some ld => - match (matrixo_to_omatrix - (build_matrix (fun i j => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv_genvar (cons (mk_env_entry xv1 (mnth le i j)) - (cons (mk_env_entry xv2 (mnth r i j)) nil)) s - (mk_genvar_env v1) with - | Some sd => Some ((mnth ld i j) * sd) - | _ => None - end))) with - | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m))) - | _ => None - end - | _ => None - end - end). - - Definition vector_env_iter {n} {A} (f: A -> df_env -> option df_env) - (env: df_env) (v : vector A n) : option df_env := - Vector.fold_right (fun a oenv => match oenv with - | Some env => f a env - | _ => None - end) - v (Some env). - - Fixpoint list_env_iter {A} (f: A -> df_env -> option df_env) - (oenv:option df_env) (l: list A) : option df_env := - match oenv, l with - | Some env, x :: l' => list_env_iter f (f x env) l' - | _, _ => oenv - end. - - Lemma list_env_iter_none {A} (f: A -> df_env -> option df_env) (l: list A) : - list_env_iter f None l = None. - Proof. - induction l. - now simpl. - now simpl. - Qed. - - Lemma list_env_iter_env_not_none {A} (f: A -> df_env -> option df_env) - (oenv : option df_env) (l: list A): - list_env_iter f oenv l <> None -> oenv <> None. - Proof. - intros. - destruct oenv. - + discriminate. - + rewrite list_env_iter_none in H. - tauto. - Qed. - - Lemma list_env_iter_app {A} (f: A -> df_env -> option df_env) - (oenv:option df_env) (l1 l2: list A) : - list_env_iter f oenv (l1++l2) = - list_env_iter f (list_env_iter f oenv l1) l2. - Proof. - revert l2 oenv. - induction l1; intros l2 oenv; simpl. - - now destruct oenv. - - destruct oenv. - + auto. - + now rewrite list_env_iter_none. - Qed. - - - Lemma list_env_iter_ext {A} f1 f2 oenv (l:list A) : - (forall x a, In x l -> f1 x a = f2 x a) -> - list_env_iter f1 oenv l = list_env_iter f2 oenv l. - Proof. - intros fa. - revert oenv. - induction l; intros oenv; intros; simpl - ; match_destr. - rewrite fa; simpl; intuition. - Qed. - - - Definition two_vector_env_iter {n} {A B} (f: A -> B -> df_env -> option df_env) - (env: df_env) (v: vector A n) (w: vector B n) : option df_env := - vector_env_iter (fun '(a,b) env => f a b env) env - (vector_zip v w). - - Fixpoint two_vector_env_iter_alt2 {n1 n2} {A B} (f: A -> B -> df_env -> option df_env) - (oenv: option df_env) (v: vector A n1) (w: vector B n2) : option df_env := - match oenv,v,w with - | Some env, Vector.cons vx _ v', Vector.cons wx _ w' => two_vector_env_iter_alt2 f (f vx wx env) v' w' - | oenv',_,_ => oenv' - end. - - Definition two_vector_env_iter_alt {n} {A B} (f: A -> B -> df_env -> option df_env) - (env: df_env) (v: vector A n) (w: vector B n) : option df_env := - list_env_iter (fun '(a,b) env => f a b env) (Some env) (combine (Vector.to_list v) (Vector.to_list w)). - - Definition matrix_env_iter {m n} {A} (f: A -> df_env -> option df_env) - (env: option df_env) (mat : matrix A m n) : option df_env := - Vector.fold_right - (fun vec oenv => - Vector.fold_right (fun a oenv => match oenv with - | Some env => f a env - | _ => None - end) vec oenv - ) mat env. - - Definition two_matrix_env_iter {n m} {A B} (f: A -> B -> df_env -> option df_env) - (env: option df_env) (v: matrix A n m) (w: matrix B n m) : option df_env := - let vw := matrix_zip v w in - matrix_env_iter (fun '(a,b) e => f a b e) env vw. - - Definition two_matrix_env_iter_alt {n m} {A B} (f: A -> B -> df_env -> option df_env) - (env: df_env) (v: matrix A n m) (w: matrix B n m) : option df_env := - list_env_iter (fun i env => list_env_iter (fun j env => f (mnth v i j) (mnth w i j) env) - (Some env) (bounded_seq0 m)) - (Some env) (bounded_seq0 n). - - - Program Definition addvar (x : var_type) (grad_env:df_env) := - (match snd x as y return snd x = y -> - definition_function_types_interp y -> - definition_function_types_interp y with - | DTfloat => fun pf grad => match vartlookup grad_env x with - | Some val => grad + ((coerce _ val):float) - | _ => grad - end - | DTVector n => fun pf grad => match vartlookup grad_env x with - | Some val => build_vector (fun i => (vnth grad i) + (vnth ((coerce _ val):vector float n) i)) - | _ => grad - end - | DTMatrix m n => fun pf grad => match vartlookup grad_env x with - | Some val => build_matrix (fun i j => (mnth ((coerce _ val):matrix float m n) i j) + (mnth grad i j)) - | _ => grad - end - end) (eq_refl _). - Next Obligation. - rewrite pf; reflexivity. - Defined. - Next Obligation. - rewrite pf; reflexivity. - Defined. - Next Obligation. - rewrite pf; reflexivity. - Defined. - - Definition gradenv_init1 (v : var_type) : env_entry_type := - mk_env_entry v - (match snd v as y return definition_function_types_interp y with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix n m => ConstMatrix n m 0 - end). - - Definition gradenv_init (dvars : list var_type) : df_env := - map gradenv_init1 dvars. - - Fixpoint df_eval_backprop_deriv {T Ann} (σ:df_env) (df:DefinedFunction Ann T) (grad_env:df_env) {struct df} : definition_function_types_interp T -> option df_env - := match df with - | Number _ _ => fun grad => Some grad_env - | Constant _ _ _ => fun grad => Some grad_env - | DVector n _ dfs => - fun grad => - ((fix tv_env_iter {n1 n2} - (oenv: option df_env) (v: vector _ n1) (w: vector _ n2) : option df_env := - match oenv,v,w with - | Some env, Vector.cons vx _ v', Vector.cons wx _ w' => - tv_env_iter (df_eval_backprop_deriv σ vx env wx) v' w' - | oenv',_,_ => oenv' - end) - n n - (Some grad_env) dfs grad ) - | DMatrix n m _ dfs => fun grad => - ((fix tm_env_iter {n1 m1 n2 m2} - (oenv: option df_env) (v: matrix _ n1 m1) (w: matrix _ n2 m2) : option df_env := - match oenv,v,w with - | Some env, Vector.cons vr _ v', Vector.cons wr _ w' => - let nenv := - ((fix tv_env_iter {n1 n2} - (oenv: option df_env) (v: vector _ n1) (w: vector _ n2) : option df_env := - match oenv,v,w with - | Some env, Vector.cons vx _ v', Vector.cons wx _ w' => - tv_env_iter (df_eval_backprop_deriv σ vx env wx) v' w' - | oenv',_,_ => oenv' - end) - m1 m2 - (Some env) vr wr ) in - tm_env_iter nenv v' w' - | oenv',_,_ => oenv' - end) - n m n m - (Some grad_env) dfs grad ) - | Var x _ => fun grad => - if vartlookup grad_env x then - Some (vart_update grad_env x (addvar x grad_env grad)) - else Some grad_env - | Plus _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad - | _ => None - end - | Minus _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (-grad) - | _ => None - end - | Times _ l r => fun grad => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - match df_eval_backprop_deriv σ l grad_env (re * grad) with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (le * grad) - | _ => None - end - | _, _ => None - end - | Divide _ l r => fun grad => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - match df_eval_backprop_deriv σ l grad_env (grad / re) with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (- le / (re * re) * grad) - | _ => None - end - | _, _ => None - end - | Square _ e => fun grad => - match df_eval σ e with - | Some ee => df_eval_backprop_deriv σ e grad_env (2 * ee * grad) - | _ => None - end - | Exp _ e => fun grad => - match df_eval σ e with - | Some ee => df_eval_backprop_deriv σ e grad_env (grad * Fexp ee) - | _ => None - end - | Log _ e => fun grad => - match df_eval σ e with - | Some ee => df_eval_backprop_deriv σ e grad_env (grad / ee) - | _ => None - end - | Abs _ e => fun grad => - match df_eval σ e with - | Some ee => df_eval_backprop_deriv σ e grad_env (grad * (sign ee)) - | _ => None - end - | Sign _ e => fun grad => df_eval_backprop_deriv σ e grad_env 0 - | PSign _ e => fun grad => df_eval_backprop_deriv σ e grad_env 0 - | Max _ l r => fun grad => - match df_eval σ l, - df_eval σ r with - | Some le, Some re => - if le <= re then - (df_eval_backprop_deriv σ r grad_env grad) else - (df_eval_backprop_deriv σ l grad_env grad) - | _, _ => None - end - | VectorElem n _ l i => fun grad => - let grad' := fun k => if proj1_sig k == proj1_sig i then grad else 0 in - df_eval_backprop_deriv σ l grad_env (build_vector grad') - | MatrixElem m n _ l i j => fun grad => - let grad' := fun k1 k2 => - if (proj1_sig k1 == proj1_sig i) then - if (proj1_sig k2 == proj1_sig j) then grad else 0 - else 0 in - df_eval_backprop_deriv σ l grad_env (build_matrix grad') - | VectorDot n _ l r => fun grad => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - match df_eval_backprop_deriv σ l grad_env (vmap (fun rv => rv*grad) re) with - | Some grad_env' => - df_eval_backprop_deriv σ r grad_env' (vmap (fun lv => lv*grad) le) - | _ => None - end - | _, _ => None - end - | VectorSum n _ l => fun grad => - df_eval_backprop_deriv σ l grad_env (ConstVector n grad) - | MatrixSum n m _ l => fun grad => - df_eval_backprop_deriv σ l grad_env (ConstMatrix n m grad) - | VectorScalMult n _ x r => fun grad => - match df_eval σ x, df_eval σ r with - | Some xe, Some re => - match df_eval_backprop_deriv σ x grad_env (vsum (vmap2 Fmult re grad)) with - | Some grad_env' => - df_eval_backprop_deriv σ r grad_env' (vmap (fun e => xe * e) grad) - | _ => None - end - | _, _ => None - end - | MatrixScalMult n m _ x r => fun grad => - match df_eval σ x, df_eval σ r with - | Some xe, Some re => - match df_eval_backprop_deriv σ x grad_env (msum (mmap2 Fmult re grad)) with - | Some grad_env' => - df_eval_backprop_deriv σ r grad_env' (mmap (fun e => e * xe) grad) - | _ => None - end - | _, _ => None - end - | MatrixVectorMult n m _ l r => fun grad => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - match df_eval_backprop_deriv σ l grad_env (build_matrix (fun i j => (vnth grad i) * (vnth re j))) with - - - | Some grad_env' => - df_eval_backprop_deriv σ r grad_env' (matrix_vector_mult (transpose le) grad) - | _ => None - end - | _, _ => None - end - | MatrixVectorAdd n m _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => - match list_env_iter - (fun i env => df_eval_backprop_deriv σ r env (vnth (transpose grad) i)) - (Some grad_env') (bounded_seq0 m) with - | Some grad_env'' => Some grad_env'' - | _ => None - end - | _ => None - end - | MatrixMult n m p _ l r => fun grad => - match df_eval σ l, df_eval σ r with - | Some le, Some re => - match df_eval_backprop_deriv σ l grad_env (matrix_mult grad (transpose re)) with - - - | Some grad_env' => - df_eval_backprop_deriv σ r grad_env' (matrix_mult (transpose le) grad) - | _ => None - end - | _, _ => None - end - | VectorPlus n _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad - | _ => None - end - | VectorMinus n _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (vmap Fopp grad) - | _ => None - end - | MatrixPlus n m _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad - | _ => None - end - | MatrixMinus n m _ l r => fun grad => - match df_eval_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (mmap Fopp grad) - | _ => None - end - | VectorApply n _ x s r => fun grad => - match df_eval σ r with - | Some re => - let xv := (x, DTfloat):var_type in - let s' := df_deriv s xv in - let ograd := - vmap (fun '(rei, g) => - - match df_eval (cons (mk_env_entry xv rei) nil) s' with - | Some se => Some (g * se) - | _ => None - end) - (vcombine re grad) in - match vectoro_to_ovector ograd with - | Some grad' => df_eval_backprop_deriv σ r grad_env grad' - | _ => None - end - | _ => None - end - | MatrixApply n m _ x s r => fun grad => - match df_eval σ r with - | Some re => - let xv := (x, DTfloat):var_type in - let s' := df_deriv s xv in - let ograd := - mmap (fun '(rei, g) => - match df_eval (cons (mk_env_entry xv rei) nil) s' with - | Some se => Some (g * se) - | _ => None - end) - (mcombine re grad) in - match matrixo_to_omatrix ograd with - | Some grad' => df_eval_backprop_deriv σ r grad_env grad' - | _ => None - end - | _ => None - end - | VLossfun n _ v1 v2 s l re => fun grad => - match df_eval σ l with - | Some le => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - let s' := df_deriv s xv1 in - let ograd := - vmap (fun '(lei, rei) => - let senv := cons (mk_env_entry xv1 lei) - (cons (mk_env_entry xv2 rei) nil) in - match df_eval senv s' with - | Some se => Some (grad * se) - | _ => None - end) - (vcombine le re) in - match vectoro_to_ovector ograd with - | Some grad' => df_eval_backprop_deriv σ l grad_env grad' - | _ => None - end - | _ => None - end - | MLossfun n m _ v1 v2 s l re => fun grad => - match df_eval σ l with - | Some le => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - let s' := df_deriv s xv1 in - let ograd := - mmap (fun '(lei, rei) => - let senv := cons (mk_env_entry xv1 lei) - (cons (mk_env_entry xv2 rei) nil) in - match df_eval senv s' with - | Some se => Some ((grad * se)/(FfromZ (Z.of_nat m))) - | _ => None - end) - (mcombine le re) in - match matrixo_to_omatrix ograd with - | Some grad' => df_eval_backprop_deriv σ l grad_env grad' - | _ => None - end - | _ => None - end - end. - - Definition lifted_type (B:Type) T - := match T with - | DTfloat => B - | DTVector n => vector B n - | DTMatrix m n => matrix B m n - end. - - Fixpoint df_eval_tree_backprop_deriv {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (grad_env:df_env) {struct df} : definition_function_types_interp T -> option df_env - := match df with - | Number _ _ => fun grad => Some grad_env - | Constant _ _ _ => fun grad => Some grad_env - | DVector n _ dfs => fun grad => - ((fix tv_env_iter {n1 n2} - (oenv: option df_env) (v: vector _ n1) (w: vector _ n2) : option df_env := - match oenv,v,w with - | Some env, Vector.cons vx _ v', Vector.cons wx _ w' => - tv_env_iter (df_eval_tree_backprop_deriv σ vx env wx) v' w' - | oenv',_,_ => oenv' - end) - n n - (Some grad_env) dfs grad ) - | DMatrix n m _ dfs => fun grad => - ((fix tm_env_iter {n1 m1 n2 m2} - (oenv: option df_env) (v: matrix _ n1 m1) (w: matrix _ n2 m2) : option df_env := - match oenv,v,w with - | Some env, Vector.cons vr _ v', Vector.cons wr _ w' => - let nenv := - ((fix tv_env_iter {n1 n2} - (oenv: option df_env) (v: vector _ n1) (w: vector _ n2) : option df_env := - match oenv,v,w with - | Some env, Vector.cons vx _ v', Vector.cons wx _ w' => - tv_env_iter (df_eval_backprop_deriv σ vx env wx) v' w' - | oenv',_,_ => oenv' - end) - m1 m2 - (Some env) vr wr ) in - tm_env_iter nenv v' w' - | oenv',_,_ => oenv' - end) - n m n m - (Some grad_env) dfs grad ) - | Var x _ => fun grad => - if vartlookup grad_env x then - Some (vart_update grad_env x (addvar x grad_env grad)) - else Some grad_env - | Plus _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad - | _ => None - end - | Minus _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (-grad) - | _ => None - end - | Times _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_backprop_deriv σ l grad_env (re * grad) with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (le * grad) - | _ => None - end - | Divide _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_backprop_deriv σ l grad_env (grad / re) with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (- le / (re * re) * grad) - | _ => None - end - | Square _ e => fun grad => - let ee := get_annotation e in - df_eval_tree_backprop_deriv σ e grad_env (2 * ee * grad) - | Exp _ e => fun grad => - let ee := get_annotation e in - df_eval_tree_backprop_deriv σ e grad_env (grad * Fexp ee) - | Log _ e => fun grad => - let ee := get_annotation e in - df_eval_tree_backprop_deriv σ e grad_env (grad / ee) - | Abs _ e => fun grad => - let ee := get_annotation e in - df_eval_tree_backprop_deriv σ e grad_env (grad * (sign ee)) - | Sign _ e => fun grad => df_eval_tree_backprop_deriv σ e grad_env 0 - | PSign _ e => fun grad => df_eval_tree_backprop_deriv σ e grad_env 0 - | Max _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - if le <= re then - (df_eval_tree_backprop_deriv σ r grad_env grad) else - (df_eval_tree_backprop_deriv σ l grad_env grad) - | VectorElem n _ l i => fun grad => - let grad' := build_vector (fun k => if proj1_sig k == proj1_sig i then grad else 0) in - df_eval_tree_backprop_deriv σ l grad_env grad' - | MatrixElem m n _ l i j => fun grad => - let grad' := build_matrix (fun k1 k2 => - if (proj1_sig k1 == proj1_sig i) then - if (proj1_sig k2 == proj1_sig j) then grad else 0 - else 0) in - df_eval_tree_backprop_deriv σ l grad_env grad' - | VectorDot n _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_backprop_deriv σ l grad_env (vmap (fun rv => rv*grad) re) with - | Some grad_env' => - df_eval_tree_backprop_deriv σ r grad_env' (vmap (fun lv => lv*grad) le) - | _ => None - end - | VectorSum n _ l => fun grad => - df_eval_tree_backprop_deriv σ l grad_env (ConstVector n grad) - | MatrixSum n m _ l => fun grad => - df_eval_tree_backprop_deriv σ l grad_env (ConstMatrix n m grad) - | VectorScalMult n _ x r => fun grad => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_backprop_deriv σ x grad_env (vsum (vmap2 Fmult re grad)) with - | Some grad_env' => - df_eval_tree_backprop_deriv σ r grad_env' (vmap (fun ge => xe * ge) grad) - | _ => None - end - | MatrixScalMult n m _ x r => fun grad => - let '(xe,re) := (get_annotation x, get_annotation r) in - match df_eval_tree_backprop_deriv σ x grad_env (msum (mmap2 Fmult re grad)) with - | Some grad_env' => - df_eval_tree_backprop_deriv σ r grad_env' (mmap (fun ge => ge*xe) grad) - | _ => None - end - | MatrixVectorMult n m _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_backprop_deriv σ l grad_env (build_matrix (fun i j => (vnth grad i) * (vnth re j))) with - | Some grad_env' => - df_eval_tree_backprop_deriv σ r grad_env' (matrix_vector_mult (transpose le) grad) - | _ => None - end - | MatrixVectorAdd n m _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => - match list_env_iter - (fun i env => df_eval_tree_backprop_deriv σ r env (vnth (transpose grad) i)) - (Some grad_env') (bounded_seq0 m) with - | Some grad_env'' => Some grad_env'' - | _ => None - end - | _ => None - end - | MatrixMult n m p _ l r => fun grad => - let '(le,re) := (get_annotation l, get_annotation r) in - match df_eval_tree_backprop_deriv σ l grad_env (matrix_mult grad (transpose re)) with - | Some grad_env' => - df_eval_tree_backprop_deriv σ r grad_env' (matrix_mult (transpose le) grad) - | _ => None - end - | VectorPlus n _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad - | _ => None - end - | VectorMinus n _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (vmap Fopp grad) - | _ => None - end - | MatrixPlus n m _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad - | _ => None - end - | MatrixMinus n m _ l r => fun grad => - match df_eval_tree_backprop_deriv σ l grad_env grad with - | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (mmap Fopp grad) - | _ => None - end - | VectorApply n _ x s r => fun grad => - let re := get_annotation r in - let xv := (x, DTfloat):var_type in - let s' := df_deriv s xv in - let ograd := - vmap (fun '(rei, g) => - match df_eval (cons (mk_env_entry xv rei) nil) s' with - | Some se => Some (g * se) - | _ => None - end) - (vector_zip re grad) in - match vectoro_to_ovector ograd with - | Some grad' => df_eval_tree_backprop_deriv σ r grad_env grad' - | _ => None - end - | MatrixApply n m _ x s r => fun grad => - let re := get_annotation r in - let xv := (x, DTfloat):var_type in - let s' := df_deriv s xv in - let ograd := - mmap (fun '(rei, g) => - match df_eval (cons (mk_env_entry xv rei) nil) s' with - | Some se => Some (g * se) - | _ => None - end) - (matrix_zip re grad) in - match matrixo_to_omatrix ograd with - | Some grad' => df_eval_tree_backprop_deriv σ r grad_env grad' - | _ => None - end - | VLossfun n _ v1 v2 s l re => fun grad => - let le := get_annotation l in - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - let s' := df_deriv s xv1 in - let ograd := - vmap (fun '(lei, rei) => - let senv := cons (mk_env_entry xv1 lei) - (cons (mk_env_entry xv2 rei) nil) in - match df_eval senv s' with - | Some se => Some (grad * se) - | _ => None - end) - (vector_zip le re) in - match vectoro_to_ovector ograd with - | Some grad' => df_eval_tree_backprop_deriv σ l grad_env grad' - | _ => None - end - | MLossfun n m _ v1 v2 s l re => fun grad => - let le := get_annotation l in - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - let s' := df_deriv s xv1 in - let ograd := - mmap (fun '(lei, rei) => - let senv := cons (mk_env_entry xv1 lei) - (cons (mk_env_entry xv2 rei) nil) in - match df_eval senv s' with - | Some se => Some ((grad * se) / (FfromZ (Z.of_nat m))) - | _ => None - end) - (matrix_zip le re) in - match matrixo_to_omatrix ograd with - | Some grad' => df_eval_tree_backprop_deriv σ l grad_env grad' - | _ => None - end - end. - - Definition o_df_env_to_df_env (oenv : option df_env) : df_env := - match oenv with - | Some env => env - | _ => nil - end. - - - Definition backprop_lookup (oenv:option df_env) (a:var_type) : - option (definition_function_types_interp (snd a)) := - match oenv with - | Some env => - match vartlookup env a with - | Some val => Some val - | _ => None - end - | _ => None - end. - - Definition is_scalar_df_type (dft:definition_function_types) : Prop - := match dft with - | DTfloat => True - | _ => False - end. - - Fixpoint is_scalar_function {Ann} {T} (df:DefinedFunction Ann T) : Prop - := match df with - | Number _ _ => True - | Constant t _ _ => is_scalar_df_type t - | Var v _ => is_scalar_df_type (snd v) - | Plus _ l r => is_scalar_function l /\ is_scalar_function r - | Minus _ l r => is_scalar_function l /\ is_scalar_function r - | Times _ l r => is_scalar_function l /\ is_scalar_function r - | Divide _ l r => is_scalar_function l /\ is_scalar_function r - | Square _ e => is_scalar_function e - | Exp _ e => is_scalar_function e - | Log _ e => is_scalar_function e - | Abs _ e => is_scalar_function e - | Sign _ e => is_scalar_function e - | PSign _ e => is_scalar_function e - | Max _ l r => is_scalar_function l /\ is_scalar_function r - | _ => False - end. - - Fixpoint has_scalar_functions {Ann} {T} - (df:DefinedFunction Ann T) {struct df}: Prop - := match df with - | Number _ _ => True - | Constant _ _ _ => True - | DVector n _ vec => - ((fix vforal {n'} (v:vector (DefinedFunction Ann DTfloat) n') : Prop := - match v with - | Vector.nil => True - | Vector.cons vx _ v' => (has_scalar_functions vx) /\ (vforal v') - end) - n vec) - | DMatrix n m _ mat => - ((fix mforall {n' m'} (mat:matrix (DefinedFunction Ann DTfloat) n' m') : Prop := - match mat with - | Vector.nil => True - | Vector.cons matr _ mat' => - (mforall mat') /\ - ((fix vforal {n'} (v:vector (DefinedFunction Ann DTfloat) n') : Prop := - match v with - | Vector.nil => True - | Vector.cons vx _ v' => (has_scalar_functions vx) /\ (vforal v') - end) - m' matr) - end) - n m mat) - | Var _ _ => True - | Plus _ l r => (has_scalar_functions l) /\ (has_scalar_functions r) - | Minus _ l r => (has_scalar_functions l) /\ (has_scalar_functions r) - | Times _ l r => (has_scalar_functions l) /\ (has_scalar_functions r) - | Divide _ l r => (has_scalar_functions l) /\ (has_scalar_functions r) - | Square _ l => has_scalar_functions l - | Exp _ l => has_scalar_functions l - | Log _ l => has_scalar_functions l - | Abs _ l => has_scalar_functions l - | Sign _ l => has_scalar_functions l - | PSign _ l => has_scalar_functions l - | Max _ l r => (has_scalar_functions l) /\ (has_scalar_functions r) - | VectorDot n _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | VectorSum _ _ l => has_scalar_functions l - | MatrixSum _ _ _ l => has_scalar_functions l - | VectorElem _ _ vec i => has_scalar_functions vec - | MatrixElem _ _ _ mat i j => has_scalar_functions mat - | MatrixVectorMult _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | MatrixVectorAdd _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | MatrixMult _ _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | VectorPlus _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | VectorMinus _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | MatrixPlus _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | MatrixMinus _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | VectorScalMult _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | MatrixScalMult _ _ _ l r => (has_scalar_functions l) /\ - (has_scalar_functions r) - | VectorApply _ _ _ s l => is_scalar_function s /\ has_scalar_functions l - | MatrixApply _ _ _ _ s l => is_scalar_function s /\ has_scalar_functions l - | VLossfun _ _ _ _ s l _ => is_scalar_function s /\ has_scalar_functions l - | MLossfun _ _ _ _ _ s l _ => is_scalar_function s /\ has_scalar_functions l - end. - - Lemma is_scalar_function_has_scalar_functions {Ann} {T} (df:DefinedFunction Ann T) : - is_scalar_function df -> has_scalar_functions df. - Proof. - induction df; firstorder. - Qed. - - Hint Resolve is_scalar_function_has_scalar_functions. - - Definition DefinedFunction_ind_unit_has_scalar_functions - (P : forall (d : definition_function_types), DefinedFunction UnitAnn d -> Prop) - (f : forall (ann : UnitAnn DTfloat) (x : float), - P DTfloat (Number ann x)) - (f0 : forall (t : definition_function_types) - (ann : UnitAnn t) (x : definition_function_types_interp t), P t (Constant ann x)) - (f1 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (x : vector (DefinedFunction UnitAnn DTfloat) n) - (f: vforall (P DTfloat) x), - P (DTVector n) (DVector ann x)) - (f2 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (x : matrix (DefinedFunction UnitAnn DTfloat) n m) - (f: mforall (P DTfloat) x), - P (DTMatrix n m) (DMatrix ann x)) - (f3 : forall (v : var_type) (ann : UnitAnn (snd v)), - P (snd v) (Var v ann)) - (f4 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Plus ann l r)) - (f5 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Minus ann l r)) - (f6 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Times ann l r)) - (f7 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Divide ann l r)) - (f8 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Square ann e)) - (f9 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Exp ann e)) - (f10 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Log ann e)) - (f11 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Abs ann e)) - (f12 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Sign ann e)) - (f13 : forall (ann : UnitAnn DTfloat) - (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (PSign ann e)) - (f14 : forall (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn DTfloat), - P DTfloat l -> - forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Max ann l r)) - (f15 : forall (n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P DTfloat (VectorDot ann l r)) - (f16 : forall (n : nat) (ann : UnitAnn DTfloat) - (v : DefinedFunction UnitAnn (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v)) - (f17 : forall (m n : nat) (ann : UnitAnn DTfloat) - (v : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v)) - (f18 : forall (n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i)) - (f19 : forall (m n : nat) (ann : UnitAnn DTfloat) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}), - P DTfloat (MatrixElem ann l i j)) - (f20 : forall (m n : nat) (ann : UnitAnn (DTVector m)) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r)) - (f21 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n)) - (l : DefinedFunction UnitAnn (DTMatrix m n)), - P (DTMatrix m n) l -> - forall r : DefinedFunction UnitAnn (DTVector m), - P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r)) - (f22 : forall (m p n : nat) (ann : UnitAnn (DTMatrix m n)) - (l : DefinedFunction UnitAnn (DTMatrix m p)), - P (DTMatrix m p) l -> - forall r : DefinedFunction UnitAnn (DTMatrix p n), - P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r)) - (f23 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r)) - (f24 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (l : DefinedFunction UnitAnn (DTVector n)), - P (DTVector n) l -> - forall r : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r)) - (f25 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (l : DefinedFunction UnitAnn (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r)) - (f26 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (l : DefinedFunction UnitAnn (DTMatrix n m)), - P (DTMatrix n m) l -> - forall r : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r)) - (f27 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (x : DefinedFunction UnitAnn DTfloat), - P DTfloat x -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l)) - (f28 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m)) - (x : DefinedFunction UnitAnn DTfloat), - P DTfloat x -> - forall l : DefinedFunction UnitAnn (DTMatrix n m), - P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l)) - (f29 : forall (n : nat) (ann : UnitAnn (DTVector n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - is_scalar_function s -> - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l)) - (f30 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n)) - (v : SubVar) (s : DefinedFunction UnitAnn DTfloat), - is_scalar_function s -> - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTMatrix m n), - P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l)) - (f31 : forall (n : nat) (ann : UnitAnn DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - is_scalar_function s -> - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTVector n), - P (DTVector n) l -> forall r : vector float n, P DTfloat (VLossfun ann v1 v2 s l r)) - (f32 : forall (m n : nat) (ann : UnitAnn DTfloat) - (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat), - is_scalar_function s -> - P DTfloat s -> - forall l : DefinedFunction UnitAnn (DTMatrix m n), - P (DTMatrix m n) l -> - forall r : matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r)) - : forall (d : definition_function_types) - (d0 : DefinedFunction UnitAnn d) - (hs:has_scalar_functions d0), P d d0. - Proof. - refine (fix - F (d : definition_function_types) - (d0 : DefinedFunction UnitAnn d) - {struct d0} : has_scalar_functions d0 -> P d d0 := - match d0 as d2 in (DefinedFunction _ d1) return has_scalar_functions d2 -> (P d1 d2) with - | Number ann x => fun hs => f ann x - | @Constant _ t ann x => fun hs => f0 t ann x - | @DVector _ n ann x => - fun hs => - f1 n ann x - ((fix F1 n (x:vector (DefinedFunction UnitAnn DTfloat) n) : vforall (P DTfloat) x := - match x with - | Vector.nil => Vector.Forall_nil (P DTfloat) - | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h _) (F1 _ tl) - end) n x) - | @DMatrix _ n m ann x => - fun hs => - f2 n m ann x - ((fix F2 n m (x:matrix (DefinedFunction UnitAnn DTfloat) n m) : mforall (P DTfloat) x := - match x with - | Vector.nil => Vector.Forall_nil (vforall (P DTfloat)) - - | Vector.cons h _ tl => - Vector.Forall_cons _ _ _ - ((fix F1 m (x:vector (DefinedFunction UnitAnn DTfloat) m) : vforall (P DTfloat) x := - match x with - | Vector.nil => Vector.Forall_nil (P DTfloat) - | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h _) (F1 _ tl) - end) m h) - (F2 _ _ tl) - end) n m x) - | Var v ann => fun hs => f3 v ann - | Plus ann l r => fun hs => f4 ann l (F DTfloat l (proj1 hs)) r (F DTfloat r (proj2 hs)) - | Minus ann l r => fun hs => f5 ann l (F DTfloat l _) r (F DTfloat r _) - | Times ann l r => fun hs => f6 ann l (F DTfloat l _) r (F DTfloat r _) - | Divide ann l r => fun hs => f7 ann l (F DTfloat l _) r (F DTfloat r _) - | Square ann e => fun hs => f8 ann e (F DTfloat e _) - | Exp ann e => fun hs => f9 ann e (F DTfloat e _) - | Log ann e => fun hs => f10 ann e (F DTfloat e _) - | Abs ann e => fun hs => f11 ann e (F DTfloat e _) - | Sign ann e => fun hs => f12 ann e (F DTfloat e _) - | PSign ann e => fun hs => f13 ann e (F DTfloat e _) - | Max ann l r => fun hs => f14 ann l (F DTfloat l _) r (F DTfloat r _) - | @VectorDot _ n ann l r => fun hs => f15 n ann l (F (DTVector n) l _) r (F (DTVector n) r _) - | @VectorSum _ n ann v => fun hs => f16 n ann v (F (DTVector n) v _) - | @MatrixSum _ m n ann v => fun hs => f17 m n ann v (F (DTMatrix m n) v _) - | @VectorElem _ n ann l i => fun hs => f18 n ann l (F (DTVector n) l _) i - | @MatrixElem _ m n ann l i j => fun hs => f19 m n ann l (F (DTMatrix m n) l _) i j - | @MatrixVectorMult _ m n ann l r => fun hs => - f20 m n ann l (F (DTMatrix m n) l _) r (F (DTVector n) r _) - | @MatrixVectorAdd _ m n ann l r => fun hs => - f21 m n ann l (F (DTMatrix m n) l _) r (F (DTVector m) r _) - | @MatrixMult _ m p n ann l r => fun hs => - f22 m p n ann l (F (DTMatrix m p) l _) r (F (DTMatrix p n) r _) - | @VectorPlus _ n ann l r => fun hs => f23 n ann l (F (DTVector n) l _) r (F (DTVector n) r _) - | @VectorMinus _ n ann l r => fun hs => f24 n ann l (F (DTVector n) l _) r (F (DTVector n) r _) - | @MatrixPlus _ n m ann l r => fun hs => f25 n m ann l (F (DTMatrix n m) l _) r (F (DTMatrix n m) r _) - | @MatrixMinus _ n m ann l r => fun hs => - f26 n m ann l (F (DTMatrix n m) l _) r (F (DTMatrix n m) r _) - | @VectorScalMult _ n ann x l => fun hs => f27 n ann x (F DTfloat x _) l (F (DTVector n) l _) - | @MatrixScalMult _ n m ann x l => fun hs => f28 n m ann x (F DTfloat x _) l (F (DTMatrix n m) l _) - | @VectorApply _ n ann v s l => fun hs => f29 n ann v s _ (F DTfloat s _) l (F (DTVector n) l _) - | @MatrixApply _ m n ann v s l => fun hs => - f30 m n ann v s _ (F DTfloat s _) l (F (DTMatrix m n) l _) - | @VLossfun _ n ann v1 v2 s l r => fun hs => - f31 n ann v1 v2 s _ (F DTfloat s _) l (F (DTVector n) l _) r - | @MLossfun _ m n ann v1 v2 s l r => fun hs => - f32 m n ann v1 v2 s _ (F DTfloat s _) l (F (DTMatrix m n) l _) r - end); simpl in hs; intuition. - - exact (proj1 (vforall_forall has_scalar_functions x) hs s). - - rewrite vforall_forall in hs. - specialize (hs s). - rewrite vforall_forall in hs. - specialize (hs s0). - exact hs. - Defined. - - Fixpoint is_df_rec_prop {Ann} {T} - (prop : forall TT:definition_function_types, - (DefinedFunction Ann TT) -> Prop) - (df:DefinedFunction Ann T) {struct df}: Prop - := prop T df /\ - match df with - | Number _ _ => True - | Constant _ _ _ => True - | DVector n _ vec => - vforall (is_df_rec_prop prop) vec - | DMatrix n m _ mat => - vforall (vforall (is_df_rec_prop prop)) mat - | Var _ _ => True - | Plus _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r) - | Minus _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r) - | Times _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r) - | Divide _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r) - | Square _ l => is_df_rec_prop prop l - | Exp _ l => is_df_rec_prop prop l - | Log _ l => is_df_rec_prop prop l - | Abs _ l => is_df_rec_prop prop l - | Sign _ l => is_df_rec_prop prop l - | PSign _ l => is_df_rec_prop prop l - | Max _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r) - | VectorDot n _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | VectorSum _ _ l => is_df_rec_prop prop l - | MatrixSum _ _ _ l => is_df_rec_prop prop l - | VectorElem _ _ vec i => is_df_rec_prop prop vec - | MatrixElem _ _ _ mat i j => is_df_rec_prop prop mat - | MatrixVectorMult _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | MatrixVectorAdd _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | MatrixMult _ _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | VectorPlus _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | VectorMinus _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | MatrixPlus _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | MatrixMinus _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | VectorScalMult _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | MatrixScalMult _ _ _ l r => (is_df_rec_prop prop l) /\ - (is_df_rec_prop prop r) - | VectorApply _ _ _ _ l => is_df_rec_prop prop l - | MatrixApply _ _ _ _ _ l => is_df_rec_prop prop l - | VLossfun _ _ _ _ _ l _ => is_df_rec_prop prop l - | MLossfun _ _ _ _ _ _ l _ => is_df_rec_prop prop l - end. - - Fixpoint df_strip_annotations {Ann} {T} - (df:DefinedFunction Ann T) {struct df}: DefinedFunction UnitAnn T - := - match df with - | Number _ x1 => Number tt x1 - | Constant t _ x => Constant tt x - | DVector n _ vec => DVector tt (vmap df_strip_annotations vec) - | DMatrix n m _ mat => DMatrix tt (vmap (vmap df_strip_annotations) mat) - | Var v _ => Var v tt - | Plus _ l r => Plus tt (df_strip_annotations l) (df_strip_annotations r) - | Minus _ l r => Minus tt (df_strip_annotations l) (df_strip_annotations r) - | Times _ l r => Times tt (df_strip_annotations l) (df_strip_annotations r) - | Divide _ l r => Divide tt (df_strip_annotations l) (df_strip_annotations r) - | Square _ l => Square tt (df_strip_annotations l) - | Exp _ l => Exp tt (df_strip_annotations l) - | Log _ l => Log tt (df_strip_annotations l) - | Abs _ l => Abs tt (df_strip_annotations l) - | Sign _ l => Sign tt (df_strip_annotations l) - | PSign _ l => PSign tt (df_strip_annotations l) - | Max _ l r => Max tt (df_strip_annotations l) (df_strip_annotations r) - | VectorDot n _ l r => VectorDot tt (df_strip_annotations l) (df_strip_annotations r) - | VectorSum n _ l => VectorSum tt (df_strip_annotations l) - | MatrixSum m n _ l => MatrixSum tt (df_strip_annotations l) - | VectorElem n _ vec i => VectorElem tt (df_strip_annotations vec) i - | MatrixElem m n _ mat i j => MatrixElem tt (df_strip_annotations mat) i j - | MatrixVectorMult m n _ l r => MatrixVectorMult tt (df_strip_annotations l) (df_strip_annotations r) - | MatrixVectorAdd m n _ l r => MatrixVectorAdd tt (df_strip_annotations l) (df_strip_annotations r) - | MatrixMult m p n _ l r => MatrixMult tt (df_strip_annotations l) (df_strip_annotations r) - | VectorPlus n _ l r => VectorPlus tt (df_strip_annotations l) (df_strip_annotations r) - | VectorMinus n _ l r => VectorMinus tt (df_strip_annotations l) (df_strip_annotations r) - | MatrixPlus m n _ l r => MatrixPlus tt (df_strip_annotations l) (df_strip_annotations r) - | MatrixMinus m n _ l r => MatrixMinus tt (df_strip_annotations l) (df_strip_annotations r) - | VectorScalMult n _ l r => VectorScalMult tt (df_strip_annotations l) (df_strip_annotations r) - | MatrixScalMult m n _ l r => MatrixScalMult tt (df_strip_annotations l) (df_strip_annotations r) - | VectorApply n _ v s l => VectorApply tt v (df_strip_annotations s) (df_strip_annotations l) - | MatrixApply m n _ v s l => MatrixApply tt v (df_strip_annotations s) (df_strip_annotations l) - | VLossfun n _ v1 v2 s l r => VLossfun tt v1 v2 (df_strip_annotations s) (df_strip_annotations l) r - | MLossfun m n _ v1 v2 s l r => MLossfun tt v1 v2 (df_strip_annotations s) (df_strip_annotations l) r - end. - - Require Import Program. - - Lemma df_strip_annotations_id {T} (df:DefinedFunction UnitAnn T) : df_strip_annotations df = df. - Proof. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case; simpl; trivial - ; destruct ann; trivial; try congruence. - - Case "DVector"%string. - f_equal. - erewrite vmap_ext; [apply vmap_id | ]; intros. - simpl. - destruct H0 as [??]; subst. - eapply H; eauto. - - Case "DMatrix"%string. - f_equal. - erewrite vmap_ext; [apply vmap_id | ]; intros. - simpl. - erewrite vmap_ext; [apply vmap_id | ]; intros. - simpl. - destruct H0 as [??]; subst. - destruct H1 as [??]; subst. - eapply H; eauto. - Qed. - - Definition df_eq_upto_annotations {Ann1 Ann2 T} - (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) : Prop - := df_strip_annotations df1 = df_strip_annotations df2. - - Definition is_df_evalann_correct_top (σ:df_env) {T} (df:DefinedFunction EvalAnn T) - := df_eval σ df = Some (get_annotation df). - - Definition is_df_evalann_correct (σ:df_env) {T} (df:DefinedFunction EvalAnn T) - := is_df_rec_prop (@is_df_evalann_correct_top σ) df. - - Lemma is_df_rec_prop_top {Ann} {T} - {prop : forall TT:definition_function_types, - (DefinedFunction Ann TT) -> Prop} - {df:DefinedFunction Ann T} : - is_df_rec_prop prop df -> - prop _ df. - Proof. - destruct df; simpl; tauto. - Qed. - - Lemma df_eval_tree_correct {T Ann} (σ:df_env) (df:DefinedFunction Ann T) (dfann:DefinedFunction EvalAnn T): - df_eval_tree σ df = Some dfann -> - is_df_evalann_correct σ dfann. - Proof. - unfold is_df_evalann_correct, is_df_evalann_correct_top. - revert dfann. - DefinedFunction_cases (induction df) Case; simpl; intros dfann eqq - ; try solve[case_eq (df_eval_tree σ df1) - ; [intros adf1 a1eqq | intros a1eqq] - ; rewrite a1eqq in eqq - ; [| congruence] - ; (case_eq (df_eval_tree σ df2) - ; [intros adf2 a2eqq | intros a2eqq] - ; rewrite a2eqq in eqq - ; [| congruence] - ; inversion eqq; simpl - ; specialize (IHdf1 _ a1eqq) - ; specialize (IHdf2 _ a2eqq) - ; split; [| tauto] - ; apply is_df_rec_prop_top in IHdf1 - ; apply is_df_rec_prop_top in IHdf2 - ; simpl in IHdf1, IHdf2 - ; rewrite IHdf1, IHdf2 - ; trivial) - - | - case_eq (df_eval_tree σ df) - ; [intros adf aeqq | intros aeqq] - ; rewrite aeqq in eqq - ; [| congruence] - ; inversion eqq; simpl - ; specialize (IHdf _ aeqq) - ; split; [| tauto] - ; apply is_df_rec_prop_top in IHdf - ; simpl in IHdf - ; rewrite IHdf - ; trivial - ]. - - - Case "Number"%string. - inversion eqq; subst. - simpl; tauto. - - Case "Constant"%string. - inversion eqq; subst. - simpl; tauto. - - Case "DVector"%string. - match_option_in eqq. - invcs eqq. - simpl. - specialize (vectoro_to_ovector_forall_some_f eqq0) - ; simpl - ; clear eqq0; intros eqq0. - split. - + apply vectoro_to_ovector_forall_some_b_strong; intros i. - specialize (H _ _ (eqq0 i)). - apply is_df_rec_prop_top in H. - simpl in *. - rewrite vmap_nth; trivial. - + apply vforall_forall; eauto. - - Case "DMatrix"%string. - match_option_in eqq. - invcs eqq. - simpl. - unfold matrixo_to_omatrix in *. - specialize (vectoro_to_ovector_forall_some_f eqq0) - ; simpl - ; clear eqq0; intros eqq0. - split. - + apply vectoro_to_ovector_forall_some_b_strong; intros i. - apply vectoro_to_ovector_forall_some_b_strong; intros j. - specialize (eqq0 i). - specialize (vectoro_to_ovector_forall_some_f eqq0) - ; simpl - ; clear eqq0; intros eqq0. - specialize (eqq0 j). - specialize (H _ _ _ eqq0). - apply is_df_rec_prop_top in H. - simpl in *. - repeat rewrite vmap_nth; trivial. - + apply vforall_forall; intros. - apply vforall_forall; intros. - specialize (eqq0 i). - specialize (vectoro_to_ovector_forall_some_f eqq0) - ; simpl - ; clear eqq0; intros eqq0. - eauto. - - Case "Var"%string. - revert eqq. - case_eq (vartlookup σ v) ; [| congruence]. - intros. - inversion eqq; subst; simpl. - rewrite H; tauto. - - Case "VectorApply"%string. - case_eq (df_eval_tree σ df2) - ; [intros adf2 a2eqq | intros a2eqq] - ; rewrite a2eqq in eqq - ; [| congruence]. - specialize (IHdf2 _ a2eqq). - match_option_in eqq. - invcs eqq. - simpl. - split; trivial. - apply is_df_rec_prop_top in IHdf2. - simpl in IHdf2. - rewrite IHdf2; trivial. - - Case "MatrixApply"%string. - case_eq (df_eval_tree σ df2) - ; [intros adf2 a2eqq | intros a2eqq] - ; rewrite a2eqq in eqq - ; [| congruence]. - specialize (IHdf2 _ a2eqq). - match_option_in eqq. - invcs eqq. - simpl. - split; trivial. - apply is_df_rec_prop_top in IHdf2. - simpl in IHdf2. - rewrite IHdf2; trivial. - - Case "VLossfun"%string. - case_eq (df_eval_tree σ df2) - ; [intros adf2 a2eqq | intros a2eqq] - ; rewrite a2eqq in eqq - ; [| congruence]. - specialize (IHdf2 _ a2eqq). - match_option_in eqq. - invcs eqq. - simpl. - split; trivial. - apply is_df_rec_prop_top in IHdf2. - simpl in IHdf2. - rewrite IHdf2, eqq0; trivial. - - Case "MLossfun"%string. - case_eq (df_eval_tree σ df2) - ; [intros adf2 a2eqq | intros a2eqq] - ; rewrite a2eqq in eqq - ; [| congruence]. - specialize (IHdf2 _ a2eqq). - match_option_in eqq. - invcs eqq. - simpl. - split; trivial. - apply is_df_rec_prop_top in IHdf2. - simpl in IHdf2. - rewrite IHdf2, eqq0; trivial. - Qed. - - - Lemma df_eval_tree_deriv_correct {T} {σ:df_env} {df:DefinedFunction EvalAnn T} : - is_df_evalann_correct σ df -> - forall (xv:var_type), - (* let xv := (v, DTfloat) in *) - df_eval_tree_deriv σ df xv = df_eval_deriv σ df xv. - Proof. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; - intro iscor; destruct iscor; - simpl; intros; trivial; unfold is_df_evalann_correct in * - ; try solve - [ - rewrite IHdf - ; trivial - | - assert (is_df_evalann_correct_top σ df) - ; [ apply is_df_rec_prop_top; trivial | - unfold is_df_evalann_correct_top in H1 - ; rewrite H1 - ; rewrite IHdf - ; trivial - ] - | - rewrite IHdf1; - [ rewrite IHdf2 - ; trivial - ; tauto - | - tauto - ] - | - rewrite IHdf1; - [ assert (is_df_evalann_correct_top σ df2); - [ apply is_df_rec_prop_top; trivial - | unfold is_df_evalann_correct_top in H1; - rewrite H1; trivial] - | tauto] - | - destruct H0; rewrite IHdf1; - [rewrite IHdf2; - [assert (is_df_evalann_correct_top σ df1); - [apply is_df_rec_prop_top; trivial - | assert (is_df_evalann_correct_top σ df2); - [ apply is_df_rec_prop_top; trivial - | - unfold is_df_evalann_correct_top in H2; - unfold is_df_evalann_correct_top in H3; - rewrite H2; rewrite H3; trivial ]] - | - tauto] - | - tauto] - ]. - - Case "DVector"%string. - f_equal. - apply FunctionalExtensionality.functional_extensionality. - intro. - apply H. - destruct H0. - rewrite vforall_forall in H1. - eauto. - - Case "DMatrix"%string. - f_equal. - apply FunctionalExtensionality.functional_extensionality. - intro. - apply FunctionalExtensionality.functional_extensionality. - intros. - apply H. - destruct H0. - rewrite vforall_forall in H1. - specialize (H1 x0). - rewrite vforall_forall in H1. - eauto. - Qed. - - Lemma df_eval_tree_backprop_deriv_correct {T} (σ gradenv:df_env) (df:DefinedFunction EvalAnn T) (grad : definition_function_types_interp T) : - is_df_evalann_correct σ df -> - df_eval_tree_backprop_deriv σ df gradenv grad = df_eval_backprop_deriv σ df gradenv grad. - Proof. - revert gradenv grad. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; - intros; simpl;trivial - ; try solve [ - destruct H; - assert (is_df_evalann_correct σ df); - [ unfold is_df_evalann_correct; trivial - | apply is_df_rec_prop_top in H0; - rewrite IHdf; - [ unfold is_df_evalann_correct_top in H0; - rewrite H0; trivial - | trivial]] - | - destruct H; - apply IHdf; - unfold is_df_evalann_correct; trivial - | - destruct H; destruct H0; rewrite IHdf1; - [ case_eq (df_eval_backprop_deriv σ df1 gradenv grad); [|congruence]; - intros; - rewrite IHdf2; trivial - | unfold is_df_evalann_correct; trivial] - - | - destruct H; - assert (is_df_evalann_correct σ df2); - [ unfold is_df_evalann_correct; trivial - | apply is_df_rec_prop_top in H0; - unfold is_df_evalann_correct_top in H0; - rewrite H0; - match_destr; - rewrite IHdf1; trivial] - | - destruct H; destruct H0; assert (is_df_evalann_correct σ df1); - [ unfold is_df_evalann_correct; trivial - | assert (is_df_evalann_correct σ df2); - [ unfold is_df_evalann_correct; trivial - | rewrite IHdf1; trivial; - apply is_df_rec_prop_top in H0; - apply is_df_rec_prop_top in H1; - unfold is_df_evalann_correct_top in H0; - unfold is_df_evalann_correct_top in H1; - rewrite H0; rewrite H1; - match_destr; - rewrite IHdf2; trivial]] - ]. - - Case "DVector"%string. - destruct H0. - rewrite vforall_forall in H1. - unfold two_vector_env_iter_alt. - f_equal; apply FunctionalExtensionality.functional_extensionality; intros - ; apply FunctionalExtensionality.functional_extensionality ; intros. - apply H; unfold is_df_evalann_correct; apply H1. - - Case "DMatrix"%string. - destruct H0. - rewrite vforall_forall in H1. - unfold two_matrix_env_iter_alt. - f_equal; apply FunctionalExtensionality.functional_extensionality; intros - ; apply FunctionalExtensionality.functional_extensionality; intros. - f_equal; apply FunctionalExtensionality.functional_extensionality; intros - ; apply FunctionalExtensionality.functional_extensionality; intros. - apply H; unfold is_df_evalann_correct. - specialize (H1 x0). - rewrite vforall_forall in H1; apply H1. - - Case "MatrixVectorAdd"%string. - destruct H. - destruct H0. - assert (is_df_evalann_correct σ df1). - unfold is_df_evalann_correct; trivial. - assert (is_df_evalann_correct σ df2). - unfold is_df_evalann_correct; trivial. - rewrite IHdf1; trivial. - match_destr. - assert - (list_env_iter - (fun (i : {m' : nat | (m' < n)%nat}) (env : df_env) => - df_eval_tree_backprop_deriv σ df2 env (transpose grad i)) - (Some d) (bounded_seq0 n) = - list_env_iter - (fun (i : {m' : nat | (m' < n)%nat}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d) - (bounded_seq0 n)). - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite IHdf2; trivial. - rewrite H4; trivial. - Qed. - - Lemma df_eval_ignores_ann {Ann T} {σ:df_env} - (df:DefinedFunction Ann T) : - df_eval σ df = df_eval σ (df_strip_annotations df). - Proof. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; simpl; trivial - ; try solve [ - rewrite IHdf; trivial - | - rewrite IHdf1; - case_eq (df_eval σ (df_strip_annotations df1)); [|congruence]; - intros; rewrite IHdf2; trivial - | - rewrite IHdf1; trivial; - case_eq (df_eval σ (df_strip_annotations df2)); [|congruence]; - intros; f_equal; - apply FunctionalExtensionality.functional_extensionality; intros; - f_equal; rewrite df_strip_annotations_id; trivial - | - rewrite IHdf1; trivial; - case_eq (df_eval σ (df_strip_annotations df2)); [|congruence]; - intros; f_equal; - rewrite df_strip_annotations_id; trivial - ]. - - - Case "DVector"%string. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite H. - rewrite vmap_nth; trivial. - - Case "DMatrix"%string. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H x0). - rewrite H. - rewrite vmap_nth. - rewrite vmap_nth; trivial. - Qed. - - Lemma df_eval_ignores_ann2 {Ann1 Ann2 T} {σ:df_env} - (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) : - df_eq_upto_annotations df1 df2 -> - df_eval σ df1 = df_eval σ df2. - Proof. - assert (df_eval σ df1 = df_eval σ (df_strip_annotations df1)) by apply df_eval_ignores_ann. - assert (df_eval σ df2 = df_eval σ (df_strip_annotations df2)) by apply df_eval_ignores_ann. - congruence. - Qed. - - Lemma df_eval_deriv_ignores_ann {Ann T} {σ:df_env} - (df:DefinedFunction Ann T) : - forall (xv:var_type), - df_eval_deriv σ df xv = df_eval_deriv σ (df_strip_annotations df) xv. - Proof. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; simpl; trivial - ; try solve - [ - intro; rewrite IHdf1; - case_eq (df_eval_deriv σ (df_strip_annotations df1) xv); [|congruence]; - intros; - rewrite IHdf2; trivial - | - intro; rewrite df_eval_ignores_ann; - case_eq (df_eval σ (df_strip_annotations df1)); [|congruence]; - intros; rewrite IHdf1; intros; - case_eq (df_eval_deriv σ (df_strip_annotations df1) xv); [|congruence]; - intros; rewrite df_eval_ignores_ann; - case_eq (df_eval σ (df_strip_annotations df2)); [|congruence]; - intros; rewrite IHdf2; trivial - | - intros; rewrite df_eval_ignores_ann; - case_eq (df_eval σ (df_strip_annotations df)); [|congruence]; - intros; rewrite IHdf; intros; - case_eq (df_eval_deriv σ (df_strip_annotations df) xv); [|congruence]; - trivial - | - intros; rewrite IHdf; trivial - | - intro; rewrite df_eval_ignores_ann; - case_eq (df_eval σ (df_strip_annotations df2)); [|congruence]; - intros; rewrite IHdf1; intros; - rewrite df_strip_annotations_id; trivial - ]. - - - Case "DVector"%string. - intros. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite H. - rewrite vmap_nth; trivial. - - Case "DMatrix"%string. - intros. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H x0). - rewrite H. - rewrite vmap_nth. - rewrite vmap_nth; trivial. - - Case "Max"%string. - intro; rewrite df_eval_ignores_ann. - case_eq (df_eval σ (df_strip_annotations df1)); [|congruence]. - intros. - rewrite df_eval_ignores_ann. - case_eq (df_eval σ (df_strip_annotations df2)); [|congruence]. - intros. - rewrite IHdf1. - rewrite IHdf2; trivial. - Qed. - - Lemma df_eval_deriv_ignores_ann2 {Ann1 Ann2 T} {σ:df_env} - (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) : - forall (xv:var_type), - df_eq_upto_annotations df1 df2 -> - df_eval_deriv σ df1 xv = df_eval_deriv σ df2 xv. - Proof. - intro. - assert (df_eval_deriv σ df1 xv = df_eval_deriv σ (df_strip_annotations df1) xv) by apply df_eval_deriv_ignores_ann. - assert (df_eval_deriv σ df2 xv = df_eval_deriv σ (df_strip_annotations df2) xv) by apply df_eval_deriv_ignores_ann. - congruence. - Qed. - - Lemma is_scalar_function_scalar {Ann} {T} (df:DefinedFunction Ann T) : - is_scalar_function df -> is_scalar_df_type T. - Proof. - induction df; simpl; trivial. - Qed. - - - - Definition definition_function_types_map_base (f:Type->Type) (dft:definition_function_types): Type - := match dft with - | DTfloat => f float - | DTVector n => Vector (f float) n - | DTMatrix m n => Matrix (f float) m n - end. - - Definition definition_function_types_subgradient (dft:definition_function_types) - := definition_function_types_map_base (fun t => list (list t)) dft. - - - Definition df_eval_gradient {T} σ (df:DefinedFunction UnitAnn T) (lv:list var_type) : option (list (definition_function_types_interp T)) - := listo_to_olist (map (df_eval_deriv σ df) lv). - - Definition combine_prod (l1 l2 : list (list float)) : list (list (float * float)) - := let l12 := list_prod l1 l2 - in map (fun '(x,y) => combine x y) l12. -(* - Fixpoint df_eval_subgradient {dft:definition_function_types} (σ:df_env) (df:DefinedFunction dft) (lv:list SubVar) : option (definition_function_types_subgradient dft) - := (match df with - | Number _ => Some ((map (fun _ => 0) lv) :: nil) - | DVector n v => vectoro_to_ovector (vmap (fun x => df_eval_subgradient σ x lv) v) - | DMatrix n m df => matrixo_to_omatrix (vmap (fun x => vmap (fun y => df_eval_subgradient σ y lv) x) df) - | Var x => Some ((map (fun v => if x == v then 1 else 0) lv) :: nil) - | Plus l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (map (map (fun '(x, y) => x+y)) (combine_prod ld rd)) - | _, _ => None - end - | Minus l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (map (map (fun '(x, y) => x-y)) (combine_prod ld rd)) - | _, _ => None - end - | Times l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - Some (map (map (fun '(lp,rp) => lp*re + le*rp)) (combine_prod ld rd)) - | _, _, _, _ => None - end - | Divide l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - Some (map (map (fun '(lp,rp) => (lp*re - le*rp)/(re * re))) (combine_prod ld rd)) - | _, _, _, _ => None - end - | Square e => - match df_eval σ e, df_eval_subgradient σ e lv with - | Some ee, Some ed => Some (map (map (fun pd => 2 * ee * pd)) ed) - | _, _ => None - end - | Exp e => - match df_eval σ e, df_eval_subgradient σ e lv with - | Some ee, Some ed => Some (map (map (fun pd => pd * Fexp ee)) ed) - | _, _ => None - end - | Log e => - match df_eval σ e, df_eval_subgradient σ e lv with - | Some ee, Some ed => Some (map (map (fun pd => (pd / ee))) ed) - | _, _ => None - end - | Abs e => - match df_eval σ e, df_eval_subgradient σ e lv with - | Some ee, Some ed => - if Feq ee 0 then Some (ed ++ (map (map (fun ep => -ep)) ed)) - else Some (map (map (fun ed => (ed * (sign ee)))) ed) - | _, _ => None - end - | Sign e => - match df_eval σ e with - | Some ee => Some ((map (fun _ => 0) lv) :: nil ) - | _ => None - end - | PSign e => - match df_eval σ e with - | Some ee => Some ((map (fun _ => 0) lv) :: nil ) - | _ => None - end - | Max l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - if Feq le re then Some (ld ++ rd) - else if le > re then Some ld - else Some rd - | _, _, _, _ => None - end - | VectorElem n l i => - match (df_eval_subgradient σ l lv) with - | Some l' => Some (l' i) - | _ => None - end - | MatrixElem m n l i j => - match (df_eval_subgradient σ l lv) with - | Some l' => Some (l' i j) - | _ => None - end - | VectorSum n l => - match df_eval_subgradient σ l lv with - | Some l' => - Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp)) - (combine_prod a b)) - ((map (fun _ => 0) lv)::nil) l') - | _ => None - end - | VectorDot n l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp)) - (combine_prod a b)) - ((map (fun _ => 0) lv)::nil) - (fun i => map (map (fun '(lp,rp) => lp*(re i) + (le i)*rp)) - (combine_prod (ld i) (rd i)))) - | _, _, _, _ => None - end - | VectorScalMult n x r => - match df_eval σ x, df_eval_subgradient σ x lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some xe, Some xd, Some re, Some rd => - Some (fun j => map (map (fun '(xp,rp) => xe * rp + xp * (re j))) (combine_prod xd (rd j))) - | _, _, _, _ => None - end - | MatrixScalMult n m x r => - match df_eval σ x, df_eval_subgradient σ x lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some xe, Some xd, Some re, Some rd => - Some (fun i j => map (map (fun '(xp,rp) => xe * rp + xp * (re i j))) (combine_prod xd (rd i j))) - - | _, _, _, _ => None - end - | MatrixVectorMult n m l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - Some (fun i => - (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp)) - (combine_prod a b)) - ((map (fun _ => 0) lv)::nil) - (fun j => map (map (fun '(lp,rp) => lp*(re j) + (le i j)*rp)) - (combine_prod (ld i j) (rd j))))) - | _, _, _, _ => None - end - | MatrixMult n m p l r => - match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with - | Some le, Some ld, Some re, Some rd => - Some (fun i k => - (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp)) - (combine_prod a b)) - ((map (fun _ => 0) lv)::nil) - (fun j => map (map (fun '(lp,rp) => lp*(re j k) + (le i j)*rp)) - (combine_prod (ld i j) (rd j k))))) - | _, _, _, _ => None - end - | VectorPlus n l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (fun i => (map (map (fun '(x, y) => x+y)) (combine_prod (ld i) (rd i)))) - | _, _ => None - end - | VectorMinus n l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (fun i => (map (map (fun '(x, y) => x-y)) (combine_prod (ld i) (rd i)))) - | _, _ => None - end - | MatrixPlus n m l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (fun i j => (map (map (fun '(x, y) => x+y)) (combine_prod (ld i j) (rd i j)))) - | _, _ => None - end - | MatrixMinus n m l r => - match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with - | Some ld, Some rd => Some (fun i j => (map (map (fun '(x, y) => x-y)) (combine_prod (ld i j) (rd i j)))) - | _, _ => None - end - | VectorApply n x s r => - match df_eval σ r, df_eval_subgradient σ r lv with - | Some re, Some rd => - vectoro_to_ovector - (fun i => match df_eval_subgradient (cons (x, re i) σ) s lv with - | Some sd => - Some (map (map (fun '(x, y) => x*y)) (combine_prod (rd i) sd)) - | _ => None - end) - | _, _ => None - end - | Lossfun n v1 v2 s l r => - match df_eval σ l, df_eval_subgradient σ l lv with - | Some le, Some ld => - match (vectoro_to_ovector - (fun i => match df_eval_subgradient (cons (v1, (le i)) (cons (v2, r i) σ)) s lv with - | Some sd => Some (map (map (fun '(x, y) => x*y)) (combine_prod (ld i) sd)) - | _ => None - end)) with - | Some vv => Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp)) - (combine_prod a b)) - ((map (fun _ => 0) lv)::nil) vv) - | _ => None - end - | _, _ => None - end - end). -*) - End deriv2. - - Definition dft_one (dft:definition_function_types) : definition_function_types_interp dft - := match dft with - | DTfloat => 1 - | DTVector n => fun _ => 1 - | DTMatrix m n => fun _ _ => 1 - end. - - Section scalar_ind. - - Fixpoint is_scalar_function_ind_gen {Ann} - {P:forall {T}, DefinedFunction Ann T->Prop} - (fnumber:forall ann x, P (Number ann x)) - (fconstant:forall (ann:Ann DTfloat) x, P (@Constant _ DTfloat ann x)) - (fvar:forall sv ann, P (@Var _ (sv,DTfloat) ann)) - (fplus:forall a l r, P l -> P r -> P (Plus a l r)) - (fminus:forall a l r, P l -> P r -> P (Minus a l r)) - (ftimes:forall a l r, P l -> P r -> P (Times a l r)) - (fdivide:forall a l r, P l -> P r -> P (Divide a l r)) - (fsquare:forall a e, P e -> P (Square a e)) - (fexp:forall a e, P e -> P (Exp a e)) - (flog:forall a e, P e -> P (Log a e)) - (fabs:forall a e, P e -> P (Abs a e)) - (fsign:forall a e, P e -> P (Sign a e)) - (fpsign:forall a e, P e -> P (PSign a e)) - (fmax:forall a l r, P l -> P r -> P (Max a l r)) - {T} - (df:DefinedFunction Ann T) {struct df} : is_scalar_function df -> P df. - Proof. - induction df; simpl; intros isc; try tauto. - - apply fnumber. - - destruct t; simpl in isc; try tauto. - apply fconstant. - - destruct v. - destruct d; simpl in isc; try tauto. - apply fvar. - - apply fplus. - + apply IHdf1; tauto. - + apply IHdf2; tauto. - - apply fminus. - + apply IHdf1; tauto. - + apply IHdf2; tauto. - - apply ftimes. - + apply IHdf1; tauto. - + apply IHdf2; tauto. - - apply fdivide. - + apply IHdf1; tauto. - + apply IHdf2; tauto. - - apply fsquare. - + apply IHdf; tauto. - - apply fexp. - + apply IHdf; tauto. - - apply flog. - + apply IHdf; tauto. - - apply fabs. - + apply IHdf; tauto. - - apply fsign. - + apply IHdf; tauto. - - apply fpsign. - + apply IHdf; tauto. - - apply fmax. - + apply IHdf1; tauto. - + apply IHdf2; tauto. - Qed. - - Definition is_scalar_function_ind {Ann} - {P:DefinedFunction Ann DTfloat->Prop} - (fnumber:forall ann x, P (Number ann x)) - (fconstant:forall (ann:Ann DTfloat) x, P (@Constant _ DTfloat ann x)) - (fvar:forall sv ann, P (@Var _ (sv,DTfloat) ann)) - (fplus:forall a l r, P l -> P r -> P (Plus a l r)) - (fminus:forall a l r, P l -> P r -> P (Minus a l r)) - (ftimes:forall a l r, P l -> P r -> P (Times a l r)) - (fdivide:forall a l r, P l -> P r -> P (Divide a l r)) - (fsquare:forall a e, P e -> P (Square a e)) - (fexp:forall a e, P e -> P (Exp a e)) - (flog:forall a e, P e -> P (Log a e)) - (fabs:forall a e, P e -> P (Abs a e)) - (fsign:forall a e, P e -> P (Sign a e)) - (fpsign:forall a e, P e -> P (PSign a e)) - (fmax:forall a l r, P l -> P r -> P (Max a l r)) - (df:DefinedFunction Ann DTfloat) : is_scalar_function df -> P df. - Proof. - apply (@is_scalar_function_ind_gen _ (fun t => match t with - | DTfloat => fun df => P df - | _ => fun _ => False - end)); trivial. - Qed. - - Definition vartlookup_eq (l1 l2:df_env) : Prop := forall a, vartlookup l1 a = vartlookup l2 a. - - Global Instance vartlookup_eq_equiv : Equivalence vartlookup_eq. - Proof. - unfold vartlookup_eq. - constructor; red. - - intros; reflexivity. - - intros; eauto. - - intro; etransitivity; eauto. - Qed. - - End scalar_ind. - - Lemma lookup_update (xv : var_type) (gradenv : df_env) - (val : definition_function_types_interp (snd xv)) : - vartlookup (vart_update gradenv xv val) xv = Some val. - Proof. - induction gradenv; simpl. - - destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence]. - refl_simpler; simpl; trivial. - - destruct a; simpl. - case_eq (@equiv_dec var_type _ _ _ xv x); simpl; intros. - + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence]. - refl_simpler; simpl; trivial. - + rewrite H; trivial. - Qed. - - Lemma lookup_update_neq (xv1 xv2 : var_type) (gradenv : df_env) - (val : definition_function_types_interp (snd xv1)) : xv1 <> xv2 -> - vartlookup (vart_update gradenv xv1 val) xv2 = vartlookup gradenv xv2. - Proof. - intros neq. - induction gradenv; simpl. - - destruct (@equiv_dec var_type _ _ _ xv2 xv1); congruence. - - destruct a; simpl. - case_eq (@equiv_dec var_type _ _ _ xv1 x); simpl; intros. - + destruct (@equiv_dec var_type _ _ _ xv2 xv1); [congruence | ]. - destruct (@equiv_dec var_type _ _ _ xv2 x); congruence. - + destruct (@equiv_dec var_type _ _ _ xv2 x); congruence. - Qed. - - Lemma lookup_update2 (xv : var_type) (gradenv : df_env) - (val : definition_function_types_interp (snd xv)) : - vartlookup ((mk_env_entry xv val) :: gradenv) xv = Some val. - Proof. - induction gradenv; simpl. - - destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence]. - refl_simpler; simpl; trivial. - - destruct a; simpl. - case_eq (@equiv_dec var_type _ _ _ xv x); simpl; intros. - + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence]. - refl_simpler; simpl; trivial. - + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence]. - refl_simpler; simpl; trivial. - Qed. - - -Tactic Notation "DefinedFunction_scalar_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string]. - - Lemma df_eval_backprop_deriv_preserves_lookup_not_none {Ann T} {env} {grad gradenv d} {df:DefinedFunction Ann T} : - df_eval_backprop_deriv env df gradenv grad = Some d -> - forall xv, - vartlookup gradenv xv <> None -> - vartlookup d xv <> None. - Proof. - simpl. - revert grad gradenv d. - DefinedFunction_cases (induction df) Case; simpl. - - Case "Number"%string; intros; inversion H; subst; easy. - - Case "Constant"%string; intros; inversion H; subst; easy. - - Case "DVector"%string. - intros grad. - unfold two_vector_env_iter_alt. - induction (bounded_seq0 n). - simpl. - intros. - inversion H0; subst; trivial. - simpl. - intros gradenv d. - case_eq (df_eval_backprop_deriv env (x a) gradenv (grad a)). - intros. - specialize (H a (grad a) gradenv d0). - specialize (IHl d0 d). - apply IHl; trivial. - apply H; trivial. - intros. - assert (list_env_iter - (fun (i : {n' : nat | (n' < n)%nat}) (env0 : df_env) => - df_eval_backprop_deriv env (x i) env0 (grad i)) None l = None) - by apply list_env_iter_none. - intros; rewrite H1 in H3; discriminate. - - Case "DMatrix"%string. - intros grad. - unfold two_matrix_env_iter_alt. - induction (bounded_seq0 n); simpl. - { intros; inversion H0; subst; trivial. } - intros gradenv d eqq. - case_eq ((list_env_iter - (fun (j : {m' : nat | (m' < m)%nat}) (env0 : df_env) => - df_eval_backprop_deriv env (x a j) env0 (grad a j)) (Some gradenv) - (bounded_seq0 m))) - ; [ intros dd ddeqq | intros ddeqq] - ; rewrite ddeqq in eqq - ; simpl in eqq - ; [| destruct l; simpl; discriminate]. - specialize (IHl _ _ eqq). - cut (forall xv : var_type, vartlookup gradenv xv <> None -> vartlookup dd xv <> None) - ; [ eauto | ]. - clear d IHl eqq. - revert gradenv dd ddeqq. - induction (bounded_seq0 m); simpl - ; intros gradenv dd ddeqq - ; simpl in ddeqq. - { inversion ddeqq; subst; trivial. } - case_eq (df_eval_backprop_deriv env (x a a0) gradenv (grad a a0)) - ; [intros dd2 ddeqq2 | intros ddeqq2] - ; rewrite ddeqq2 in ddeqq - ; simpl in ddeqq - ; [| destruct l0; simpl; discriminate]. - eauto. - - Case "Var"%string. - intros. - destruct (vartlookup gradenv v) ; [|congruence]. - intros. - inversion H. - destruct (vart_dec v xv). - + subst; rewrite lookup_update. - discriminate. - + rewrite lookup_update_neq; trivial. - - Case "Plus"%string. - intros grad gradenv. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d). - specialize (IHdf2 grad d d0). - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "Minus"%string. - intros grad gradenv. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d). - specialize (IHdf2 (-grad) d d0). - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "Times"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv (d1 * grad)) ; [|congruence]. - intros d2. - specialize (IHdf1 (d1 * grad) gradenv d2). - specialize (IHdf2 (d0 * grad) d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "Divide"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv (grad / d1)) ; [|congruence]. - intros d2. - specialize (IHdf1 (grad / d1) gradenv d2). - specialize (IHdf2 (- d0 / (d1 * d1) * grad) d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "Square"%string; intros. - destruct (df_eval env df) ; [|congruence]. - specialize (IHdf (2 * d0 * grad) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "Exp"%string; intros. - destruct (df_eval env df) ; [|congruence]. - specialize (IHdf (grad * Fexp d0) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "Log"%string; intros. - destruct (df_eval env df) ; [|congruence]. - specialize (IHdf (grad / d0) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "Abs"%string; intros. - destruct (df_eval env df) ; [|congruence]. - specialize (IHdf (grad * sign d0) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "Sign"%string; intros. - specialize (IHdf 0 gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "PSign"%string; intros. - specialize (IHdf 0 gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "Max"%string; intros. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - destruct (d0 <= d1). - specialize (IHdf2 grad gradenv d). - apply IHdf2. - apply H. - trivial. - specialize (IHdf1 grad gradenv d). - apply IHdf1. - apply H. - trivial. - - Case "VectorDot"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv - (vmap (fun rv : float => rv * grad) d1)); [|congruence]. - intros. - specialize (IHdf1 (vmap (fun rv : float => rv * grad) d1) gradenv d2). - specialize (IHdf2 (vmap (fun lv : float => lv * grad) d0) d2 d). - apply IHdf2. - apply H0. - apply IHdf1. - apply H. - apply H1. - - Case "VectorSum"%string. - intros. - specialize (IHdf (ConstVector n grad) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "MatrixSum"%string. - intros. - specialize (IHdf (ConstMatrix m n grad) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "VectorElem"%string. - intros. - specialize (IHdf (fun k : {n' : nat | (n' < n)%nat} => - if equiv_dec (proj1_sig k) (proj1_sig i) - then grad else 0) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "MatrixElem"%string. - intros. - specialize (IHdf (fun (k1 : {n' : nat | (n' < m)%nat}) - (k2 : {m' : nat | (m' < n)%nat}) => - if equiv_dec (proj1_sig k1) (proj1_sig i) - then if equiv_dec (proj1_sig k2) (proj1_sig j) then grad else 0 - else 0) gradenv d). - apply IHdf. - apply H. - apply H0. - - Case "MatrixVectorMult"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv - (fun (i : {n' : nat | (n' < m)%nat}) - (j : {m' : nat | (m' < n)%nat}) => grad i * d1 j)) ; [|congruence]. - intros d2. - specialize (IHdf1 (fun (i : {n' : nat | (n' < m)%nat}) - (j : {m' : nat | (m' < n)%nat}) => grad i * d1 j) gradenv d2). - specialize (IHdf2 (matrix_vector_mult - (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < m)%nat}) => d0 j i) grad) d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "MatrixVectorAdd"%string. - intros grad gradenv d. - case_eq (df_eval_backprop_deriv env df1 gradenv grad); [|congruence]. - intros d0 casedf1. - specialize (IHdf1 _ _ _ casedf1). - clear casedf1. - revert gradenv d d0 IHdf1. - induction (bounded_seq0 n). - + simpl. - intros. - inversion H; subst. - eauto. - + intros gradenv d d0 d0eqq. - simpl. - case_eq (df_eval_backprop_deriv env df2 d0 (transpose grad a)); simpl - ; [intros ? eqq1 | intros eqq1]. - * intros. - { apply (IHl d0 _ d1); trivial. - - eapply IHdf2; eauto. - - eauto. - } - * destruct l; simpl; discriminate. - - Case "MatrixMult"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv - (matrix_mult grad - (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < p)%nat}) => d1 j i))) - ; [|congruence]. - intros d2. - specialize (IHdf1 (matrix_mult grad - (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < p)%nat}) => d1 j i)) - gradenv d2). - specialize (IHdf2 (matrix_mult - (fun (i : {n' : nat | (n' < p)%nat}) - (j : {m' : nat | (m' < m)%nat}) => d0 j i) grad) d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "VectorPlus"%string. - intros grad gradenv d. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d0). - specialize (IHdf2 grad d0 d). - apply IHdf2. - apply H0. - apply IHdf1. - apply H. - apply H1. - - Case "VectorMinus"%string. - intros grad gradenv d. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d0). - specialize (IHdf2 (fun i : {n' : nat | (n' < n)%nat} => - grad i) d0 d). - apply IHdf2. - apply H0. - apply IHdf1. - apply H. - apply H1. - - Case "MatrixPlus"%string. - intros grad gradenv d. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d0). - specialize (IHdf2 grad d0 d). - apply IHdf2. - apply H0. - apply IHdf1. - apply H. - apply H1. - - Case "MatrixMinus"%string. - intros grad gradenv d. - case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence]. - intros. - specialize (IHdf1 grad gradenv d0). - specialize (IHdf2 (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < m)%nat}) => - grad i j) - d0 d). - apply IHdf2. - apply H0. - apply IHdf1. - apply H. - apply H1. - - Case "VectorScalMult"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv - (vsum (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad j))) - ; [|congruence]. - intros d2. - specialize (IHdf1 (vsum (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad j)) - gradenv d2). - specialize (IHdf2 (fun j : {n' : nat | (n' < n)%nat} => d0 * grad j) - d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "MatrixScalMult"%string. - intros grad gradenv d. - destruct (df_eval env df1) ; [|congruence]. - destruct (df_eval env df2) ; [|congruence]. - case_eq (df_eval_backprop_deriv env df1 gradenv - (msum - (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < m)%nat}) => - d1 i j * grad i j))) - ; [|congruence]. - intros d2. - specialize (IHdf1 (msum - (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < m)%nat}) => - d1 i j * grad i j)) - gradenv d2). - specialize (IHdf2 (fun (i : {n' : nat | (n' < n)%nat}) - (j : {m' : nat | (m' < m)%nat}) => grad i j * d0) - d2 d). - intros. - apply IHdf2; trivial. - apply IHdf1; trivial. - - Case "VectorApply"%string. - intros grad gradenv d. - destruct (df_eval env df2) ; [|congruence]. - simpl in *. - match_destr; simpl; eauto. - - Case "MatrixApply"%string. - intros grad gradenv d. - destruct (df_eval env df2) ; [|congruence]. - match_destr; simpl; eauto. - - Case "VLossfun"%string. - intros grad gradenv d. - destruct (df_eval env df2) ; [|congruence]. - match_destr; simpl; eauto. - - Case "MLossfun"%string. - intros grad gradenv d. - destruct (df_eval env df2) ; [|congruence]. - match_destr; simpl; eauto. - Qed. - - Definition df_eval_deriv_gen_top {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) : - option (lifted_type (definition_function_types_interp T) (snd v)) := - match (snd v) as vt return option (lifted_type (definition_function_types_interp T) vt) with - | DTfloat => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTfloat) 1)::nil) - | DTVector n => - vectoro_to_ovector - (fun i => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTVector n) (UnitVector n i))::nil)) - | DTMatrix n m => - matrixo_to_omatrix - (fun i j => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTMatrix n m) (UnitMatrix n m i j))::nil)) - end. - - Program Definition subvar (x : var_type) (grad_env:df_env) := - (match snd x as y return snd x = y -> - definition_function_types_interp y -> - definition_function_types_interp y with - | DTfloat => fun pf grad => match vartlookup grad_env x with - | Some val => ((coerce _ val):float) - grad - | _ => Fopp grad - end - | DTVector n => fun pf grad => match vartlookup grad_env x with - | Some val => fun i => (((coerce _ val):Vector float n) i) - (grad i) - | _ => vmap Fopp grad - end - | DTMatrix m n => fun pf grad => match vartlookup grad_env x with - | Some val => fun i j => (((coerce _ val):Matrix float m n) i j) - (grad i j) - | _ => vmap (vmap Fopp) grad - end - end) (eq_refl _). - Next Obligation. - rewrite pf; reflexivity. - Defined. - Next Obligation. - rewrite pf; reflexivity. - Defined. - Next Obligation. - rewrite pf; reflexivity. - Defined. - - Definition df_eval_backprop_delta {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) (grad_env:df_env) (grad: definition_function_types_interp T) : - option (definition_function_types_interp (snd v)) := - match vartlookup grad_env v with - | Some old => - lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env grad) - | None => None - end. - -(* - Program Definition df_eval_backward_gen_top {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) (grad_env:df_env) : - option (lifted_type (definition_function_types_interp (snd v)) T) := - (match T as vt return T = vt -> option (lifted_type (definition_function_types_interp (snd v)) vt) with - | DTfloat => fun pf => df_eval_backprop_delta σ df v grad_env (coerce _ 1) - | DTVector n => fun pf => - vectoro_to_ovector - (fun i => df_eval_backprop_delta σ df v grad_env (coerce _ (UnitVector n i))) - | DTMatrix m n => fun pf => - matrixo_to_omatrix - (fun i j => df_eval_backprop_delta σ df v grad_env (coerce _ (UnitMatrix m n i j))) - end) (eq_refl _). - *) - - Program Definition df_eval_backward_gen_top {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) (grad_env:df_env) : - option (lifted_type (definition_function_types_interp (snd v)) T) := - match vartlookup grad_env v with - | Some old => - (match T as vt return T = vt -> option (lifted_type (definition_function_types_interp (snd v)) vt) with - | DTfloat => fun pf => (lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ 1))) - | DTVector n => fun pf => - vectoro_to_ovector - (fun i => lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ (UnitVector n i)))) - | DTMatrix m n => fun pf => - matrixo_to_omatrix - (fun i j => lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ (UnitMatrix m n i j)))) - end) (eq_refl _) - | None => None - end. - - Definition transpose_lifted_type {T1 T2} : - lifted_type (definition_function_types_interp T1) T2 -> - lifted_type (definition_function_types_interp T2) T1 - := match T1, T2 with - | DTfloat, _ => fun inp => inp - | _, DTfloat => fun inp => inp - | DTVector n1, DTVector n2 => fun inp => fun i j => inp j i - | DTMatrix m1 n1, DTMatrix m2 n2 => fun inp => fun i j p q => inp p q i j - | DTVector n1, DTMatrix m2 n2 => fun inp => fun i p q => inp p q i - | DTMatrix m1 n1, DTVector n2 => fun inp => fun i j p => inp p i j - end. - Section deriv_deriv. - End deriv_deriv. - - Section max_derived. - Definition MaxDerived (a b : DefinedFunction UnitAnn DTfloat) := - Divide tt (Plus tt (Plus tt (Abs tt (Minus tt b a)) b) a) (Number tt 2). - - Delimit Scope df_scope with df. - - Notation "x + y" := (Plus x y) (only printing) : df_scope. - Notation "x - y" := (Minus x y) (only printing) : df_scope. - Notation "x / y" := (Divide x y) (only printing) : df_scope. - Notation "x * y" := (Times x y) (only printing) : df_scope. - Notation "x" := (Number x) (only printing, at level 0) : df_scope. - Notation "x" := (Var x) (only printing, at level 0) : df_scope. - Notation "'|' x '|'" := (Abs x) (only printing, at level 0) : df_scope. - - End max_derived. - - Section fv. - - Fixpoint df_free_variables {Ann} {T} (f : DefinedFunction Ann T) : list var_type - := match f with - | Number _ x => nil - | DVector n _ x => vlconcat_map df_free_variables x - | Constant t _ x => nil - | DMatrix n m _ x => vlconcat_map (fun a => vlconcat_map df_free_variables a) x - | Var v _ => v::nil - | Plus _ l r => (df_free_variables l) ++ (df_free_variables r) - | Minus _ l r => (df_free_variables l) ++ (df_free_variables r) - | Times _ l r => (df_free_variables l) ++ (df_free_variables r) - | Divide _ l r => (df_free_variables l) ++ (df_free_variables r) - | Max _ l r => (df_free_variables l) ++ (df_free_variables r) - | Abs _ e => df_free_variables e - | Sign _ e => df_free_variables e - | PSign _ e => df_free_variables e - | Log _ e => df_free_variables e - | Square _ e => df_free_variables e - | Exp _ e => df_free_variables e - | VectorElem n _ l i => df_free_variables l - | MatrixElem m n _ l i j => df_free_variables l - | VectorDot n _ l r => (df_free_variables l) ++ (df_free_variables r) - | VectorSum n _ l => df_free_variables l - | MatrixSum n m _ l => df_free_variables l - | VectorScalMult n _ x r => (df_free_variables x) ++ (df_free_variables r) - | MatrixScalMult n m _ x r => (df_free_variables x) ++ (df_free_variables r) - | MatrixVectorMult n m _ l r => (df_free_variables l) ++ (df_free_variables r) - | MatrixVectorAdd n m _ l r => (df_free_variables l) ++ (df_free_variables r) - | MatrixMult n m p _ l r => (df_free_variables l) ++ (df_free_variables r) - | VectorPlus n _ l r => (df_free_variables l) ++ (df_free_variables r) - | VectorMinus n _ l r => (df_free_variables l) ++ (df_free_variables r) - | MatrixPlus n m _ l r => (df_free_variables l) ++ (df_free_variables r) - | MatrixMinus n m _ l r => (df_free_variables l) ++ (df_free_variables r) - | VectorApply n _ x s l => (remove_all (x,DTfloat) (df_free_variables s)) - ++ (df_free_variables l) - | MatrixApply n m _ x s l => (remove_all (x,DTfloat) (df_free_variables s)) - ++ (df_free_variables l) - | VLossfun n _ v1 v2 s l r => (remove_all (v1,DTfloat) (remove_all (v2,DTfloat) (df_free_variables s))) - ++ (df_free_variables l) - | MLossfun n m _ v1 v2 s l r => (remove_all (v1,DTfloat) (remove_all (v2,DTfloat) (df_free_variables s))) - ++ (df_free_variables l) - end. - - Definition df_closed {Ann} {T} (f: DefinedFunction Ann T) : Prop - := match df_free_variables f with - | nil => True - | _ => False - end. - - Lemma df_closed_nil {T} (f: DefinedFunction UnitAnn T) : df_closed f -> df_free_variables f = nil. - Proof. - unfold df_closed. - destruct (df_free_variables f); tauto. - Qed. - - Definition df_closed_over {Ann} {T} (f : DefinedFunction Ann T) (vl : list var_type) : Prop - := incl (df_free_variables f) vl. - - Fixpoint fully_closed_over {Ann} {T} (df : DefinedFunction Ann T) (vl : list var_type) : Prop - := - match df with - | Number _ x => True - | DVector n _ x => vforall (fun f => fully_closed_over f vl) x - | Constant t _ x => True - | DMatrix n m _ x => vforall (fun row => - (vforall (fun f => fully_closed_over f vl) row)) x - | Var v _ => In v vl - | Plus _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | Minus _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | Times _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | Divide _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | Max _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | Abs _ e => fully_closed_over e vl - | Sign _ e => fully_closed_over e vl - | PSign _ e => fully_closed_over e vl - | Log _ e => fully_closed_over e vl - | Square _ e => fully_closed_over e vl - | Exp _ e => fully_closed_over e vl - | VectorElem n _ l i => fully_closed_over l vl - | MatrixElem m n _ l i j => fully_closed_over l vl - | VectorDot n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | VectorSum n _ l => fully_closed_over l vl - | MatrixSum n m _ l => fully_closed_over l vl - | VectorScalMult n _ x r => (fully_closed_over x vl) /\ (fully_closed_over r vl) - | MatrixScalMult n m _ x r => (fully_closed_over x vl) /\ (fully_closed_over r vl) - | MatrixVectorMult n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | MatrixVectorAdd n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | MatrixMult n m p _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | VectorPlus n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | VectorMinus n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | MatrixPlus n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | MatrixMinus n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl) - | VectorApply n _ x s l => (fully_closed_over s ((x,DTfloat)::nil)) /\ - (fully_closed_over l vl) - | MatrixApply n m _ x s l => (fully_closed_over s ((x,DTfloat)::nil)) /\ - (fully_closed_over l vl) - | VLossfun n _ v1 v2 s l r => (fully_closed_over s ((v1,DTfloat)::(v2,DTfloat)::nil)) - /\ (fully_closed_over l vl) - | MLossfun n m _ v1 v2 s l r => (fully_closed_over s ((v1,DTfloat)::(v2,DTfloat)::nil)) - /\ (fully_closed_over l vl) - end. - - Definition In_compat_map (f : list var_type -> list var_type) : Prop := - forall (v : var_type) (vl : list var_type), - In v vl -> In v (f vl). - - Definition map_tl (f : list var_type -> list var_type) (vl : list var_type) := - match vl with - | a :: vl1 => a :: f vl1 - | _ => f vl - end. - - Lemma In_compat_map_tl (f : list var_type -> list var_type) : - In_compat_map f -> In_compat_map (map_tl f). - Proof. - unfold In_compat_map; intros. - destruct vl. - + now simpl. - + simpl in *. - destruct H0. - * now left. - * right; now apply H. - Qed. - - Lemma fully_closed_over_map {T} (df : DefinedFunction UnitAnn T) (vl : list var_type) (f : list var_type -> list var_type) : - In_compat_map f -> fully_closed_over df vl -> fully_closed_over df (f vl). - Proof. - revert f; revert vl. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros; try solve [ - trivial - | - apply IHdf; trivial - | - split; destruct H0; - [apply IHdf1; trivial - | apply IHdf2; trivial] - ]. - - Case "DVector"%string. - apply vforall_forall; intros. - apply H; trivial. - now rewrite vforall_forall in H1. - - Case "DMatrix"%string. - apply vforall_forall; intros. - apply vforall_forall; intros. - apply H; trivial. - rewrite vforall_forall in H1. - specialize (H1 i). - now rewrite vforall_forall in H1. - - now apply H. - - Case "VectorApply"%string. - split; destruct H0; trivial. - now apply IHdf2. - - Case "MatrixApply"%string. - split; destruct H0; trivial. - now apply IHdf2. - - Case "VLossfun"%string. - split; destruct H0; trivial. - now apply IHdf2. - - Case "MLossfun"%string. - split; destruct H0; trivial. - now apply IHdf2. - Qed. - - (* - Lemma closed_is_fully_closed {Ann} {T} (df : DefinedFunction Ann T) (vl : list var_type) : - df_closed_over df vl <-> fully_closed_over df vl. -*) - - -(* - Lemma df_subst_nfree {T} (e: DefinedFunction T) (v:SubVar) (e':DefinedFunction DTfloat) : - ~ In v (df_free_variables e) -> - df_subst e v e' = e. - Proof. - DefinedFunction_cases (induction e) Case; simpl; trivial; intros nin - ; try solve [try rewrite in_app_iff in nin - ; intuition congruence]. - - Case "DVector"%string. - f_equal. - apply functional_extensionality. - intros x0. - apply H. - intros inn. - apply nin. - unfold vlconcat_map, vlconcat. - apply concat_In. - exists ((df_free_variables (x x0))). - split; trivial. - apply vector_to_list_In. - - - Case "DMatrix"%string. - - - Case "Var"%string. - destruct (var_dec v0 v); intuition. - Qed. - - Lemma df_eval_complete' {T} (σ:df_env) (f:DefinedFunction T) : - incl (df_free_variables f) (domain σ) -> {v | df_eval σ f = Some v}. - Proof. - induction f; simpl; intros inc - ; try solve [rewrite <- incl_app_iff in inc - ; intuition - ; destruct X as [v1 ev1] - ; destruct X0 as [v2 ev2] - ; rewrite ev1; rewrite ev2 - ; eauto - | intuition - ; destruct X as [v1 ev1] - ; rewrite ev1 - ; eauto]. - - eauto. - - apply in_dom_lookup_strong. - specialize (inc v); simpl in *. - intuition. - Qed. - - (* This version has better computational properties *) - Lemma df_eval_complete (σ:df_env) (f:DefinedFunction) : - incl (df_free_variables f) (domain σ) -> {v | df_eval σ f = Some v}. - Proof. - case_eq (df_eval σ f); simpl. - - intros r ?? ; exists r; eauto. - - intros ? inc. - destruct (df_eval_complete' _ _ inc); congruence. - Defined. - - Lemma df_eval_none (σ:df_env) (f:DefinedFunction) : - df_eval σ f = None -> - {v | In v (df_free_variables f) /\ ~ In v (domain σ)}. - Proof. - intros. - destruct (incl_dec (df_free_variables f) (domain σ)). - - destruct (df_eval_complete _ _ i); congruence. - - apply (nincl_exists) in n; trivial. - Qed. - - (* Either we can evaluate df or we are missing a variable definition. - Note that this theorem may fail to hold if we change the definition of - division to make it partial. - *) - Lemma df_eval_compute (σ:df_env) (f:DefinedFunction) : - {v | df_eval σ f = Some v} + {x | In x (df_free_variables f) /\ ~ In x (domain σ)}. - Proof. - case_eq (df_eval σ f); simpl. - - eauto. - - intros H; apply df_eval_none in H; eauto. - Defined. - - Lemma df_eval_closed (f:DefinedFunction) : - df_closed f -> {v | df_eval nil f = Some v}. - Proof. - intros c. - apply (df_eval_complete nil f). - rewrite df_closed_nil by trivial. - simpl; reflexivity. - Defined. - - Lemma df_eval_lookup_on (σ₁ σ₂:df_env) (f:DefinedFunction) : - lookup_equiv_on (df_free_variables f) σ₁ σ₂ -> - df_eval σ₁ f = df_eval σ₂ f. - Proof. - intros lookeq. - induction f; simpl in *; trivial - ; try solve [apply lookup_equiv_on_dom_app in lookeq; intuition - ; rewrite H1, H2; trivial - | rewrite IHf; trivial]. - - apply lookeq; simpl; tauto. - Qed. -*) - End fv. - - Section apply. - - Fixpoint df_apply {T} (e: DefinedFunction UnitAnn T) - (args: forall (v:var_type), DefinedFunction UnitAnn (snd v)) : DefinedFunction UnitAnn T := - match e with - | Number _ x => Number tt x - | Constant t _ x => Constant tt x - | DVector n _ df => DVector tt (fun x => df_apply (df x) args) - | DMatrix n m _ df => DMatrix tt (fun i j => df_apply (df i j) args) - | Var v _ => args v - | Plus _ l r => Plus tt (df_apply l args) (df_apply r args) - | Times _ l r => Times tt (df_apply l args) (df_apply r args) - | Minus _ l r => Minus tt (df_apply l args) (df_apply r args) - | Divide _ l r => Divide tt (df_apply l args) (df_apply r args) - | Square _ e => Square tt (df_apply e args) - | Exp _ e => Exp tt (df_apply e args) - | Log _ e => Log tt (df_apply e args) - | Abs _ e => Abs tt (df_apply e args) - | Sign _ e => Sign tt (df_apply e args) - | PSign _ e => PSign tt (df_apply e args) - | Max _ l r => Max tt (df_apply l args) (df_apply r args) - | VectorElem n _ l i => VectorElem tt (df_apply l args) i - | MatrixElem m n _ l i j => MatrixElem tt (df_apply l args) i j - | VectorDot n _ l r => VectorDot tt (df_apply l args) (df_apply r args) - | VectorSum n _ l => VectorSum tt (df_apply l args) - | MatrixSum n m _ l => MatrixSum tt (df_apply l args) - | VectorScalMult n _ x r => VectorScalMult tt (df_apply x args) (df_apply r args) - | MatrixScalMult n m _ x r => MatrixScalMult tt (df_apply x args) (df_apply r args) - | MatrixVectorMult n m _ l r => MatrixVectorMult tt (df_apply l args) (df_apply r args) - | MatrixVectorAdd n m _ l r => MatrixVectorAdd tt (df_apply l args) (df_apply r args) - | MatrixMult n m p _ l r => MatrixMult tt (df_apply l args) (df_apply r args) - | VectorPlus n _ l r => VectorPlus tt (df_apply l args) (df_apply r args) - | VectorMinus n _ l r => VectorMinus tt (df_apply l args) (df_apply r args) - | MatrixPlus n m _ l r => MatrixPlus tt (df_apply l args) (df_apply r args) - | MatrixMinus n m _ l r => MatrixMinus tt (df_apply l args) (df_apply r args) - | VectorApply n _ x s l => VectorApply tt x (df_apply s args) (df_apply l args) - | MatrixApply n m _ x s l => MatrixApply tt x (df_apply s args) (df_apply l args) - | VLossfun n _ v1 v2 s l r => VLossfun tt v1 v2 (df_apply s args) (df_apply l args) r - | MLossfun n m _ v1 v2 s l r => MLossfun tt v1 v2 (df_apply s args) (df_apply l args) r - end. - - End apply. - -End DefinedFunctions. - -Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "DVector"%string - | Case_aux c "DMatrix"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string - | Case_aux c "VectorDot"%string - | Case_aux c "VectorSum"%string - | Case_aux c "MatrixSum"%string - | Case_aux c "VectorElem"%string - | Case_aux c "MatrixElem"%string - | Case_aux c "MatrixVectorMult"%string - | Case_aux c "MatrixVectorAdd"%string - | Case_aux c "MatrixMult"%string - | Case_aux c "VectorPlus"%string - | Case_aux c "VectorMinus"%string - | Case_aux c "MatrixPlus"%string - | Case_aux c "MatrixMinus"%string - | Case_aux c "VectorScalMult"%string - | Case_aux c "MatrixScalMult"%string - | Case_aux c "VectorApply"%string - | Case_aux c "MatrixApply"%string - | Case_aux c "VLossfun"%string - | Case_aux c "MLossfun"%string]. - -Ltac refl_simpler := - repeat - match goal with - | [H: @eq var_type _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H) - | [H: @equiv var_type _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H) - | [H: @eq definition_function_types _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H) - | [H: @equiv definition_function_types _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H) - end. - -Section real_pfs. - - Local Existing Instance floatish_R. - Import Reals. - Import List. - - Lemma MaxDerivedMax_eq (a b : DefinedFunction UnitAnn DTfloat) : - forall σ, df_eval σ (Max tt a b) = df_eval σ (MaxDerived a b). - Proof. - simpl; intros σ. - destruct (df_eval σ a); destruct (df_eval σ b); trivial. - f_equal. - autorewrite with Rarith in *. - destruct (Rle_dec d d0). - - rewrite Rmax_right by trivial. - rewrite Rabs_pos_eq by lra. - lra. - - rewrite Rmax_left by lra. - rewrite Rabs_minus_sym. - rewrite Rabs_pos_eq by lra. - lra. - Qed. - -(* Lemma coerce_dec_id {A} (dec:forall x y:A, {x=y}+{x<>y}) (x:A) (pf:x=x) : coerce pf x = x. - Proof. - unfold coerce. - replace pf with (eq_refl A); trivial. - apply UIP_dec. - apply dec. - generalize (@UIP_dec A dec pf). - Lemma var_type_UIP_refl {x:var_type} (e:x=x) : e = eq_refl x. - Proof. - apply (UIP_dec vart_dec x x pf). - Qed. - - unfold coerce. - destruct pf. - destruct pf. - exact a. - Defined. -*) - -Tactic Notation "DefinedFunction_scalar_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string]. - - - Lemma backpropeq_gen (x : SubVar) (env gradenv : df_env) (dfexpr : DefinedFunction UnitAnn DTfloat) (grad : float) : - let xvar := (x, DTfloat) in - is_scalar_function dfexpr -> - vartlookup gradenv (x,DTfloat) <> None -> - match df_eval_deriv env dfexpr xvar, - backprop_lookup (Some gradenv) xvar, - backprop_lookup (df_eval_backprop_deriv env dfexpr gradenv grad) xvar - with - | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1 - | None, _, None => True - | _, _, _ => False - end. - Proof. - simpl. - intros is_scalar. - generalize is_scalar. - revert grad gradenv. - pattern dfexpr. - revert dfexpr is_scalar. - DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl. - - Case "Number"%string. - intros _ _ grad gradenv _ xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Constant"%string. - intros _ _ grad gradenv _ xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Var"%string. - intros sv _ grad gradenv _ xinn. - case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]. - destruct (var_dec x sv); simpl. - + subst. - rewrite H; simpl. - rewrite lookup_update. - destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence]. - unfold addvar; simpl. - rewrite H. - lra. - + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ]. - case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros. - * rewrite lookup_update_neq by congruence. - rewrite H. - lra. - * rewrite H. - lra. - - Case "Plus"%string. - intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn. - case_eq (df_eval_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv isc1 xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr grad ge' isc2). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_backprop_deriv env r ge' grad) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr; trivial. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv isc1 xinn). - case_eq (df_eval_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Minus"%string. - intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn. - case_eq (df_eval_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv isc1 xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- grad)%R ge' isc2). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr; trivial. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv isc1 xinn). - case_eq (df_eval_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Times"%string. - intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn. - case_eq (df_eval env l); - [ intros le eqle | intros eqle]; simpl; trivial. - case_eq (df_eval_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval env r); - [ intros re eqre | intros eqre] - ; simpl; trivial. - case_eq (df_eval_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (re * grad)%R gradenv isc1 xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env l gradenv (re * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (le * grad)%R ge' isc2). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (le * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_backprop_deriv env l gradenv (re * grad)%R ); simpl; trivial; intros. - apply IHr; trivial. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + case_eq (df_eval env r); - [ intros re eqre | intros eqre] - ; simpl; trivial. - specialize (IHl (re * grad)%R gradenv isc1 xinn). - case_eq (df_eval_backprop_deriv env l gradenv (re * grad)%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Divide"%string. - intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn. - case_eq (df_eval env l); - [ intros le eqle | intros eqle]; simpl; trivial. - case_eq (df_eval_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval env r); - [ intros re eqre | intros eqre] - ; simpl; trivial. - case_eq (df_eval_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (grad / re)%R gradenv isc1 xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env l gradenv (grad / re)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- le / (re * re) * grad)%R ge' isc2). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (- le / (re * re) * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_backprop_deriv env l gradenv (grad / re)%R ); simpl; trivial; intros. - apply IHr; trivial. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + case_eq (df_eval env r); - [ intros re eqre | intros eqre] - ; simpl; trivial. - specialize (IHl (grad / re)%R gradenv isc1 xinn). - case_eq (df_eval_backprop_deriv env l gradenv (grad / re)%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Square"%string. - intros _ e IHe grad gradenv isc xinn. - case_eq (df_eval env e); - [ intros le eqee | intros eqee]; simpl; trivial. - - specialize (IHe (2 * le * grad)%R gradenv isc xinn). - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env e gradenv (2 * le * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Exp"%string. - intros _ e IHe grad gradenv isc xinn. - case_eq (df_eval env e); - [ intros le eqee | intros eqee]; simpl; trivial. - - specialize (IHe (grad * exp le)%R gradenv isc xinn). - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env e gradenv (grad * exp le)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Log"%string. - intros _ e IHe grad gradenv isc xinn. - case_eq (df_eval env e); - [ intros le eqee | intros eqee]; simpl; trivial. - - specialize (IHe (grad / le)%R gradenv isc xinn). - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env e gradenv (grad / le)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Abs"%string. - intros _ e IHe grad gradenv isc xinn. - case_eq (df_eval env e); - [ intros le eqee | intros eqee]; simpl; trivial. - - specialize (IHe (grad * (sign le))%R gradenv isc xinn). - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_backprop_deriv env e gradenv (grad * (sign le))%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Sign"%string. - intros _ e IHe grad gradenv isc xinn. - specialize (IHe 0%R gradenv isc xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - lra. - - Case "PSign"%string. - intros _ e IHe grad gradenv isc xinn. - specialize (IHe 0%R gradenv isc xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - lra. - - Case "Max"%string. - intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn. - specialize (IHl grad gradenv isc1 xinn). - specialize (IHr grad gradenv isc2 xinn). - - case_eq (df_eval env l); simpl; trivial - ; intros eld eqeld. - case_eq (df_eval env r ); simpl; intros; trivial. - destruct (Rle_dec eld d); simpl. - + destruct (df_eval_deriv env r (x, DTfloat)); simpl; trivial. - + destruct (df_eval_deriv env l (x, DTfloat)); simpl; trivial. - Qed. - - (* - - Lemma tree_backpropeq_gen (x : SubVar) (env gradenv : df_env) - (dfexpr : DefinedFunction EvalAnn DTfloat) (grad : float) : - let xvar := (x, DTfloat) in - is_scalar_function dfexpr -> - vartlookup gradenv (x,DTfloat) <> None -> - match df_eval_tree_deriv env dfexpr xvar, - backprop_lookup (Some gradenv) xvar, - backprop_lookup (df_eval_tree_backprop_deriv env dfexpr gradenv grad) xvar - with - | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1 - | None, _, None => True - | _, _, _ => False - end. - Proof. - simpl. - intros is_scalar. - revert grad gradenv. - pattern dfexpr. - revert dfexpr is_scalar. - DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl. - - - Case "Number"%string. - intros _ _ grad gradenv xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Constant"%string. - intros _ _ grad gradenv xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Var"%string. - intros sv _ grad gradenv xinn. - case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]. - destruct (var_dec x sv); simpl. - + subst. - rewrite H; simpl. - rewrite lookup_update. - destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence]. - unfold addvar; simpl. - rewrite H. - lra. - + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ]. - case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros. - * rewrite lookup_update_neq by congruence. - rewrite H. - lra. - * rewrite H. - lra. - - Case "Plus"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr grad ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' grad) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Minus"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Times"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (get_annotation r * grad)%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (get_annotation l * grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (get_annotation l * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl (get_annotation r * grad)%R gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Divide"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (grad / get_annotation r)%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl (grad / get_annotation r)%R gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r )%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Square"%string. - intros _ e IHe grad gradenv xinn. - - specialize (IHe (2 * (get_annotation e) * grad)%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (2 * (get_annotation e) * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Exp"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad * exp (get_annotation e))%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * exp (get_annotation e))%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - - Case "Log"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad / get_annotation e)%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad / get_annotation e)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - - Case "Abs"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad * (sign (get_annotation e)))%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * (sign (get_annotation e)))%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Sign"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe 0%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - replace (de * 0)%R with (0)%R in IHe by lra. - replace (0 * grad)%R with (0)%R by lra. - apply IHe. - - Case "PSign"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe 0%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - replace (de * 0)%R with (0)%R in IHe by lra. - replace (0 * grad)%R with (0)%R by lra. - apply IHe. - - Case "Max"%string. - intros _ l r IHl IHr grad gradenv xinn. - specialize (IHl grad gradenv xinn). - specialize (IHr grad gradenv xinn). - destruct (Rle_dec (get_annotation l) (get_annotation r)); simpl. - destruct (df_eval_tree_deriv env r (x, DTfloat)); simpl; trivial. - destruct (df_eval_tree_deriv env l (x, DTfloat)); simpl; trivial. - Qed. - *) - - Lemma eval_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T) : - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> df_eval σ df <> None. - Proof. - revert σ. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros; - try solve - [congruence - | - destruct H; simpl in *; - specialize (IHdf1 σ); specialize (IHdf2 σ); - match_option; [|tauto]; - cut_to IHdf2; - [ match_option; tauto | easy] - | - specialize (IHdf σ); simpl in IHdf; - cut_to IHdf; - [ match_option; tauto | easy] - ]. - - Case "DVector"%string. - apply vectoro_to_ovector_not_none; intro. - specialize (H i σ); simpl in H; apply H. - rewrite vforall_forall in H0. - now specialize (H0 i). - - Case "DMatrix"%string. - unfold matrixo_to_omatrix. - apply vectoro_to_ovector_not_none; intro. - apply vectoro_to_ovector_not_none; intro. - specialize (H i i0 σ); simpl in H; apply H. - rewrite vforall_forall in H0; specialize (H0 i). - rewrite vforall_forall in H0; now specialize (H0 i0). - - Case "Var"%string. - induction σ. - + simpl in H; tauto. - + simpl in *. - match_case; intros. - destruct H. - * congruence. - * now apply IHσ. - - Case "VectorApply"%string. - destruct H; simpl in *. - specialize (IHdf2 σ). - cut_to IHdf2; trivial. - match_option; [|tauto]. - apply vectoro_to_ovector_not_none; intro. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil)). - now apply IHdf1. - - Case "MatrixApply"%string. - destruct H; simpl in *. - specialize (IHdf2 σ). - cut_to IHdf2; trivial. - match_option; [|tauto]. - unfold matrixo_to_omatrix. - apply vectoro_to_ovector_not_none; intro. - apply vectoro_to_ovector_not_none; intro. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil)). - now apply IHdf1. - - Case "VLossfun"%string. - destruct H; simpl in *. - specialize (IHdf2 σ). - cut_to IHdf2; trivial. - match_option; [|tauto]. - match_option. - apply vectoro_to_ovector_not_none in eqq0. - + tauto. - + intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil)). - now apply IHdf1. - - Case "MLossfun"%string. - destruct H; simpl in *. - specialize (IHdf2 σ). - cut_to IHdf2; trivial. - match_option; [|tauto]. - unfold matrixo_to_omatrix. - match_option. - apply vectoro_to_ovector_not_none in eqq0. - + tauto. - + intros. - apply vectoro_to_ovector_not_none; intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil)). - now apply IHdf1. - Qed. - - Lemma eval_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T) : - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - {d:definition_function_types_interp T | df_eval σ df = Some d}. - Proof. - intros. - case_eq (df_eval σ df); intros. - - now exists d. - - generalize (eval_fully_closed_not_none σ df). - intros; simpl in *. - cut_to H1; tauto. - Qed. - - Lemma closed_over_cons {T} (df:DefinedFunction UnitAnn T) (v:var_type) (vl : list var_type): - df_closed_over df vl -> df_closed_over df (v::vl). - Proof. - unfold df_closed_over. - intros. - apply incl_tl. - apply H. - Qed. - - Lemma fully_closed_over_cons {T} (df:DefinedFunction UnitAnn T) (v:var_type) - (vl : list var_type): - fully_closed_over df vl -> fully_closed_over df (v::vl). - Proof. - intros. - apply (fully_closed_over_map df vl (fun vl1 => cons v vl1)); trivial. - unfold In_compat_map. - intros. - now apply in_cons. - Qed. - - Lemma fully_closed_over_exchange_vars {T} (df:DefinedFunction UnitAnn T) (v1 v:var_type) - (vl : list var_type): - fully_closed_over df (v1 :: v :: vl) -> fully_closed_over df (v :: v1 :: vl). - Proof. - intros. - apply (fully_closed_over_map df (v1 :: v :: vl) - (fun vl1 => match vl1 with - | a :: b :: vl2 => b :: a :: vl2 - | _ => vl1 - end )); trivial. - unfold In_compat_map. - intros. - destruct vl0; trivial. - destruct vl0; trivial. - unfold In. - unfold In in H0. - tauto. - Qed. - - Lemma fully_closed_over_singleton {T} (df:DefinedFunction UnitAnn T) (v:var_type) - (vl : list var_type): - fully_closed_over df (v::nil) -> fully_closed_over df (v::vl). - Proof. - intros. - induction vl; trivial. - apply fully_closed_over_exchange_vars. - now apply fully_closed_over_cons. - Qed. - - Lemma fully_closed_over_exchange_2vars {T} (df:DefinedFunction UnitAnn T) - (v1 v2 v:var_type) (vl : list var_type): - fully_closed_over df (v1 :: v2 :: v:: vl) -> fully_closed_over df (v :: v1 :: v2 :: vl). - Proof. - intros. - apply (fully_closed_over_map df (v1 :: v2 :: v :: vl) - (fun vl1 => match vl1 with - | a :: b :: c :: vl2 => c :: a :: b :: vl2 - | _ => vl1 - end )); trivial. - unfold In_compat_map. - intros. - destruct vl0; trivial. - destruct vl0; trivial. - destruct vl0; trivial. - unfold In. - unfold In in H0. - tauto. - Qed. - - Lemma fully_closed_over_pair {T} (df:DefinedFunction UnitAnn T) (v1 v2:var_type) - (vl : list var_type): - fully_closed_over df (v1::v2::nil) -> fully_closed_over df (v1::v2::vl). - Proof. - intros. - induction vl; trivial. - apply fully_closed_over_exchange_2vars. - apply fully_closed_over_exchange_2vars. - now apply fully_closed_over_cons. - Qed. - - Lemma fully_closed_subst {T} (vl:list var_type) (df:DefinedFunction UnitAnn T) (v:var_type) - (e':DefinedFunction UnitAnn (snd v)): - fully_closed_over df (v::vl) -> - fully_closed_over e' vl -> - fully_closed_over (df_subst df v e') vl. - Proof. - revert vl. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros; try solve - [easy - | - apply IHdf; [apply H | apply H0] - | - split; destruct H; simpl in *; - [ apply IHdf1; [apply H | apply H0] - | apply IHdf2; [apply H1 | apply H0]] - ]. - - Case "DVector"%string. - apply vforall_forall; intros; simpl in *. - specialize (H i); apply H. - rewrite vforall_forall in H0. - specialize (H0 i). - apply H0. - apply H1. - - Case "DMatrix"%string. - apply vforall_forall; intros. - apply vforall_forall; intros; simpl in *. - specialize (H i i0); apply H. - rewrite vforall_forall in H0; specialize (H0 i). - rewrite vforall_forall in H0; specialize (H0 i0). - apply H0. - apply H1. - - Case "Var"%string. - unfold substvar. - destruct H. - + subst. - unfold substvar. - match_destr; [ | congruence]. - refl_simpler. - simpl; trivial. - + destruct v; destruct v0. - simpl in *. - match_destr. - red in e; subst. - simpl; trivial. - - Case "VectorApply"%string. - destruct H; split; trivial. - + apply IHdf2. - * apply H1. - * apply H0. - - Case "MatrixApply"%string. - destruct H; split; trivial. - + apply IHdf2. - * apply H1. - * apply H0. - - Case "VLossfun"%string. - destruct H; split; trivial. - + apply IHdf2. - * apply H1. - * apply H0. - - Case "MLossfun"%string. - destruct H; split; trivial. - + apply IHdf2. - * apply H1. - * apply H0. - Qed. - - Lemma fully_closed_deriv {T} (df:DefinedFunction UnitAnn T) (s:SubVar) - (vl : list var_type): - fully_closed_over df vl -> - fully_closed_over (df_deriv df (s, DTfloat)) vl. - Proof. - revert s; revert vl. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros; try solve - [easy - | - apply IHdf,H - | - split; try easy; now apply IHdf - | - destruct H; repeat split; try easy; - [apply IHdf2, H0 | apply IHdf1, H] - | - destruct H; repeat split; try easy; - [ apply IHdf1, H | apply IHdf2, H0] - ]. - - Case "DVector"%string. - apply vforall_forall; intros. - apply H. - rewrite vforall_forall in H0. - now apply H0. - - Case "DMatrix"%string. - apply vforall_forall; intros. - apply vforall_forall; intros. - apply H. - rewrite vforall_forall in H0. - specialize (H0 i). - rewrite vforall_forall in H0. - now specialize (H0 i0). - - Case "Max"%string. - destruct H; repeat split; try easy. - apply IHdf2, H0. - apply IHdf1, H. - apply IHdf2, H0. - apply IHdf1, H. - - Case "VectorApply"%string. - apply vforall_forall; intros; simpl in *. - split; destruct H. - apply IHdf2, H0. - apply fully_closed_subst. - + apply IHdf1. - now apply fully_closed_over_singleton. - + simpl; apply H0. - - Case "MatrixApply"%string. - apply vforall_forall; intros; simpl in *. - apply vforall_forall; intros; simpl in *. - split; destruct H. - apply IHdf2, H0. - apply fully_closed_subst. - + apply IHdf1. - now apply fully_closed_over_singleton. - + simpl; apply H0. - - Case "VLossfun"%string. - intros; simpl in *. - split; destruct H. - apply IHdf2, H0. - apply vforall_forall; intros; simpl in *. - apply fully_closed_subst. - apply fully_closed_subst. - + apply IHdf1. - now apply fully_closed_over_pair. - + simpl; apply fully_closed_over_cons; apply H0. - + now simpl. - - Case "MLossfun"%string. - intros; simpl in *. - apply vforall_forall; intros; simpl in *. - apply vforall_forall; intros; simpl in *. - destruct H; split; trivial. - split. - apply IHdf2, H0. - apply fully_closed_subst. - apply fully_closed_subst. - + apply IHdf1. - now apply fully_closed_over_pair. - + simpl; apply fully_closed_over_cons; apply H0. - + now simpl. - Qed. - - Lemma list_env_iter_total_fun {A} (f : A -> df_env -> option df_env) (env : df_env) (l : list A) : - (forall (a:A) (env0: df_env), (f a env0) <> None) -> - list_env_iter f (Some env) l <> None. - Proof. - intros. - generalize env. - induction l; [simpl; congruence|]. - simpl; intros. - specialize (H a env0). - case_eq (f a env0). - - intros; apply (IHl d). - - tauto. - Qed. - - Lemma backprop_deriv_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T) - (grad_env:df_env) (grad: definition_function_types_interp T): - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> df_eval_backprop_deriv σ df grad_env grad <> None. - Proof. - revert grad_env. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros; - try solve [congruence - | - specialize (IHdf1 grad grad_env); simpl in IHdf1; destruct H; - match_option; [|tauto]; - specialize (IHdf2 grad d); simpl in IHdf2; - now apply IHdf2 - ]. - - Case "DVector"%string. - unfold two_vector_env_iter_alt. - rewrite vforall_forall in H0. - apply (list_env_iter_total_fun - (fun i env => df_eval_backprop_deriv σ (x i) env (grad i)) - grad_env (bounded_seq0 n)). - intros. - apply (H a (grad a) env0). - apply (H0 a). - - Case "DMatrix"%string. - unfold two_matrix_env_iter_alt. - rewrite vforall_forall in H0. - apply (list_env_iter_total_fun - (fun i env => - list_env_iter - (fun j env0 => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (Some env) (bounded_seq0 m)) - grad_env (bounded_seq0 n)). - intros. - apply (list_env_iter_total_fun - (fun j env => df_eval_backprop_deriv σ (x a j) env (grad a j)) - env0 (bounded_seq0 m)). - intros. - apply (H a a0 (grad a a0) env1). - specialize (H0 a). - rewrite vforall_forall in H0. - apply (H0 a0). - - Case "Var"%string. - match_destr. - - Case "Minus"%string. - specialize (IHdf1 grad grad_env); simpl in IHdf1; destruct H; - match_option; [|tauto]; - specialize (IHdf2 (-grad)%R d); simpl in IHdf2; - now apply IHdf2. - - Case "Times"%string. - destruct H. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2. - + match_option; [|tauto]. - specialize (IHdf1 (d0 * grad)%R grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (d * grad)%R d1); simpl in IHdf2. - now apply IHdf2. - - Case "Divide"%string. - destruct H. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2. - + match_option; [|tauto]. - specialize (IHdf1 (grad / d0)%R grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (-d / (d0 * d0) * grad)%R d1); simpl in IHdf2. - now apply IHdf2. - - Case "Square"%string. - generalize (eval_fully_closed_not_none σ df); intros; simpl in H0. - match_option; [|tauto]. - specialize (IHdf (2 * d * grad)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "Exp"%string. - generalize (eval_fully_closed_not_none σ df); intros; simpl in H0. - match_option; [|tauto]. - specialize (IHdf (grad * exp d)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "Log"%string. - generalize (eval_fully_closed_not_none σ df); intros; simpl in H0. - match_option; [|tauto]. - specialize (IHdf (grad / d)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "Abs"%string. - generalize (eval_fully_closed_not_none σ df); intros; simpl in H0. - match_option; [|tauto]. - specialize (IHdf (grad * sign d)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "Sign"%string. - specialize (IHdf (0)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "PSign"%string. - specialize (IHdf (0)%R grad_env); simpl in IHdf. - now apply IHdf. - - Case "Max"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - match_destr. - specialize (IHdf2 grad grad_env). - now apply IHdf2. - specialize (IHdf1 grad grad_env). - now apply IHdf1. - - Case "VectorDot"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - specialize (IHdf1 (vmap (fun rv : R => (rv * grad)%R) d0) grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (vmap (fun lv : R => (lv * grad)%R) d) d1). - now apply IHdf2. - - Case "VectorSum"%string. - specialize (IHdf (ConstVector n grad) grad_env). - now apply IHdf. - - Case "MatrixSum"%string. - specialize (IHdf (ConstMatrix m n grad) grad_env). - now apply IHdf. - - Case "VectorElem"%string. - specialize (IHdf (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i) then grad else 0%R) grad_env). - now apply IHdf. - - Case "MatrixElem"%string. - specialize (IHdf (fun (k1 : {n' : nat | n' < m}) - (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) then if - equiv_dec (` k2) (` j) then grad else 0%R else 0%R) - grad_env). - now apply IHdf. - - Case "MatrixVectorMult"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - specialize (IHdf1 (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => (grad i * d0 j)%R) - grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d j i) - grad) - d1). - now apply IHdf2. - - Case "MatrixVectorAdd"%string. - specialize (IHdf1 grad grad_env); simpl in IHdf1. - match_option; [|tauto]. - match_option. - generalize (list_env_iter_total_fun - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (@transpose R m n grad i)) - d (bounded_seq0 n)); intros. - cut_to H0; [congruence|]. - intros; destruct H. - now apply IHdf2. - - Case "MatrixMult"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - specialize (IHdf1 (matrix_mult grad (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < p}) => d0 j i)) - grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d j i) grad) - d1). - now apply IHdf2. - - Case "VectorMinus"%string. - specialize (IHdf1 grad grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad i)%R) d). - now apply IHdf2. - - Case "MatrixMinus"%string. - specialize (IHdf1 grad grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (- grad i j)%R) - d). - now apply IHdf2. - - Case "VectorScalMult"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - specialize (IHdf1 (vsum (fun j : {n' : nat | n' < n} => (d0 j * grad j)%R)) - grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (fun j : {n' : nat | n' < n} => (d * grad j)%R) d1). - now apply IHdf2. - - Case "MatrixScalMult"%string. - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - match_option; [|tauto]. - specialize (IHdf1 (msum - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (d0 i j * grad i j)%R)) - grad_env); simpl in IHdf1. - match_option; [|tauto]. - specialize (IHdf2 (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad i j * d)%R) - d1). - now apply IHdf2. - - Case "VectorApply"%string. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0. - match_option; [|tauto]. - match_option. - + specialize (IHdf2 v0 grad_env). - now apply IHdf2. - + apply vectoro_to_ovector_exists_None in eqq0. - destruct eqq0. - rewrite vmap_nth in e; simpl in e. - destruct H. - match_option_in e. - generalize (fully_closed_deriv df1 v ((v,DTfloat):: nil)); intros. - cut_to H2; trivial. - generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (d x) :: nil) - (df_deriv df1 (v, DTfloat))); intros. - simpl in H3; cut_to H3; tauto. - - Case "MatrixApply"%string. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0. - match_option; [|tauto]. - match_option. - + specialize (IHdf2 m0 grad_env). - now apply IHdf2. - + unfold matrixo_to_omatrix in eqq0. - apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0. - apply vectoro_to_ovector_exists_None in e; destruct e. - unfold mmap in e. - rewrite vmap_nth in e; simpl in e. - rewrite vmap_nth in e; simpl in e. - destruct H. - unfold matrix_zip in e. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - generalize (fully_closed_deriv df1 v ((v,DTfloat):: nil)). - intros; cut_to H2; trivial. - generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (d x x0) :: nil) - (df_deriv df1 (v, DTfloat))); intros. - simpl in H3; cut_to H3; tauto. - - Case "VLossfun"%string. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0. - match_option; [|tauto]. - match_option. - + specialize (IHdf2 v grad_env). - now apply IHdf2. - + apply vectoro_to_ovector_exists_None in eqq0. - destruct eqq0. - rewrite vmap_nth in e; simpl in e. - destruct H. - match_option_in e. - generalize (fully_closed_deriv df1 v1 ((v1,DTfloat)::(v2,DTfloat)::nil)). - intros; cut_to H2; trivial. - generalize (eval_fully_closed_not_none (mk_env_entry (v1, DTfloat) (d x) :: - mk_env_entry (v2, DTfloat) (r x) :: nil) - (df_deriv df1 (v1, DTfloat))); intros. - simpl in H3; cut_to H3; tauto. - - Case "MLossfun"%string. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0. - match_option; [|tauto]. - match_option. - + specialize (IHdf2 m0 grad_env). - now apply IHdf2. - + unfold matrixo_to_omatrix in eqq0. - apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0. - apply vectoro_to_ovector_exists_None in e; destruct e. - unfold mmap in e. - rewrite vmap_nth in e; simpl in e. - rewrite vmap_nth in e; simpl in e. - destruct H. - unfold matrix_zip in e. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - generalize (fully_closed_deriv df1 v1 ((v1,DTfloat)::(v2,DTfloat)::nil)). - intros; cut_to H2; trivial. - generalize (eval_fully_closed_not_none (mk_env_entry (v1, DTfloat) (d x x0) :: - mk_env_entry (v2, DTfloat) (r x x0) :: nil) - (df_deriv df1 (v1, DTfloat))); intros. - simpl in H3; cut_to H3; tauto. - Qed. - - Lemma backprop_deriv_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T) - (grad_env:df_env) (grad: definition_function_types_interp T): - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - {d:df_env | df_eval_backprop_deriv σ df grad_env grad = Some d}. - Proof. - case_eq (df_eval_backprop_deriv σ df grad_env grad); intros. - - now exists d. - - generalize (backprop_deriv_fully_closed_not_none σ df grad_env grad). - intros; simpl in *. - cut_to H1; tauto. - Qed. - - Lemma eval_deriv_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T) - (v:var_type): - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> df_eval_deriv σ df v <> None. - Proof. - revert σ v. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; intros; simpl in *; - try solve [ - congruence - | - destruct H - ; specialize (IHdf1 σ v); specialize (IHdf2 σ v) - ;cut_to IHdf1; trivial - ;match_option; [|tauto] - ;cut_to IHdf2; trivial - ;match_option; tauto - | - destruct H; - specialize (IHdf1 σ v); specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1; - cut_to H1; trivial; - match_option; [|tauto]; - cut_to IHdf1; trivial; - match_option; [|tauto]; - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2; - match_option; [|tauto]; - cut_to IHdf2; trivial; - match_option; tauto - | - generalize (eval_fully_closed_not_none σ df); intros; - specialize (IHdf σ v); - simpl in H0; cut_to H0; trivial; - match_option; [|tauto]; - cut_to IHdf; trivial; - match_option; tauto - | - specialize (IHdf σ v); - generalize (eval_fully_closed_not_none σ df); intros; - simpl in H0; cut_to H0; trivial; - match_option; tauto - ]. - - Case "DVector"%string. - apply vectoro_to_ovector_not_none; intros; apply H. - rewrite vforall_forall in H0; apply H0. - - Case "DMatrix"%string. - apply vectoro_to_ovector_not_none; intros. - apply vectoro_to_ovector_not_none; intros; apply H. - rewrite vforall_forall in H0; specialize (H0 i). - rewrite vforall_forall in H0; apply H0. - - Case "Max"%string. - destruct H; - specialize (IHdf1 σ v); specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1; - cut_to H1; trivial; - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2; - match_option; [|tauto]. - case_eq ( Rle_dec d d0 ); intros. - cut_to IHdf2; trivial. - cut_to IHdf1; trivial. - - Case "VectorApply"%string. - destruct H. - specialize (IHdf2 σ v0); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - apply vectoro_to_ovector_not_none; intros. - match_option. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil) (v, DTfloat)). - cut_to IHdf1; trivial. - tauto. - - Case "MatrixApply"%string. - destruct H. - specialize (IHdf2 σ v0); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - apply vectoro_to_ovector_not_none; intros. - apply vectoro_to_ovector_not_none; intros. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil) (v, DTfloat)). - cut_to IHdf1; trivial. - match_option; tauto. - - Case "VLossfun"%string. - destruct H. - specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - match_option. - apply vectoro_to_ovector_not_none in eqq1. - tauto. - intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil) (v1, DTfloat)). - cut_to IHdf1; trivial. - match_option; tauto. - - Case "MLossfun"%string. - destruct H. - specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - match_option. - unfold matrixo_to_omatrix in eqq1. - apply vectoro_to_ovector_not_none in eqq1. - tauto. - intros. - apply vectoro_to_ovector_not_none; intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil)(v1, DTfloat)). - cut_to IHdf1; trivial. - match_option; tauto. - Qed. - - Lemma eval_deriv_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T) - (v:var_type): - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - {d:definition_function_types_interp T | df_eval_deriv σ df v = Some d}. - Proof. - case_eq (df_eval_deriv σ df v); intros. - - now exists d. - - generalize (eval_deriv_fully_closed_not_none σ df v). - intros; simpl in *. - cut_to H1; tauto. - Qed. - - Lemma eval_deriv_genvar_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T) - (v:df_env): - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> df_eval_deriv_genvar σ df v <> None. - Proof. - revert σ v. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; intros; simpl in *; - try solve [ - congruence - | - destruct H - ; specialize (IHdf1 σ v); specialize (IHdf2 σ v) - ;cut_to IHdf1; trivial - ;match_option; [|tauto] - ;cut_to IHdf2; trivial - ;match_option; tauto - | - destruct H; - specialize (IHdf1 σ v); specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1; - cut_to H1; trivial; - match_option; [|tauto]; - cut_to IHdf1; trivial; - match_option; [|tauto]; - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2; - match_option; [|tauto]; - cut_to IHdf2; trivial; - match_option; tauto - | - generalize (eval_fully_closed_not_none σ df); intros; - specialize (IHdf σ v); - simpl in H0; cut_to H0; trivial; - match_option; [|tauto]; - cut_to IHdf; trivial; - match_option; tauto - | - specialize (IHdf σ v); - generalize (eval_fully_closed_not_none σ df); intros; - simpl in H0; cut_to H0; trivial; - match_option; tauto - ]. - - Case "DVector"%string. - apply vectoro_to_ovector_not_none; intros; apply H. - rewrite vforall_forall in H0; apply H0. - - Case "DMatrix"%string. - apply vectoro_to_ovector_not_none; intros. - apply vectoro_to_ovector_not_none; intros; apply H. - rewrite vforall_forall in H0; specialize (H0 i). - rewrite vforall_forall in H0; apply H0. - - Case "Max"%string. - destruct H; - specialize (IHdf1 σ v); specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1; - cut_to H1; trivial; - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2; - match_option; [|tauto]. - case_eq ( Rle_dec d d0 ); intros. - cut_to IHdf2; trivial. - cut_to IHdf1; trivial. - - Case "VectorApply"%string. - destruct H. - specialize (IHdf2 σ v0); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - apply vectoro_to_ovector_not_none; intros. - match_option. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil) (mk_genvar_env v)). - cut_to IHdf1; trivial. - tauto. - - Case "MatrixApply"%string. - destruct H. - specialize (IHdf2 σ v0); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - apply vectoro_to_ovector_not_none; intros. - apply vectoro_to_ovector_not_none; intros. - specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil) (mk_genvar_env v)). - cut_to IHdf1; trivial. - match_option; tauto. - - Case "VLossfun"%string. - destruct H. - specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - match_option. - apply vectoro_to_ovector_not_none in eqq1. - tauto. - intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil) (mk_genvar_env v1)). - cut_to IHdf1; trivial. - match_option; tauto. - - Case "MLossfun"%string. - destruct H. - specialize (IHdf2 σ v); - generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1. - cut_to H1; trivial. - match_option; [|tauto]. - cut_to IHdf2; trivial. - match_option; [|tauto]. - match_option. - unfold matrixo_to_omatrix in eqq1. - apply vectoro_to_ovector_not_none in eqq1. - tauto. - intros. - apply vectoro_to_ovector_not_none; intros. - specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil) (mk_genvar_env v1)). - cut_to IHdf1; trivial. - match_option; tauto. - Qed. - - Definition scalarMult (T : definition_function_types) (c : float) := - match T return - definition_function_types_interp T -> definition_function_types_interp T with - | DTfloat => fun f => (c * f)%R - | DTVector n => fun f => fun i => (c * f i)%R - | DTMatrix n m => fun f => fun i j => (c * f i j)%R - end. - - Definition dfti_gen_plus {T : definition_function_types} := - match T return - definition_function_types_interp T -> definition_function_types_interp T -> - definition_function_types_interp T with - | DTfloat => fun f g => (f + g)%R - | DTVector n => fun f g => fun i => (f i + g i)%R - | DTMatrix n m => fun f g => fun i j => (f i j + g i j)%R - end. - - Lemma subvar_addvar_scalar_neq (env : df_env) (oval : float) (s : SubVar) (v: var_type) (grad : definition_function_types_interp (snd v)) : - let sv := (s, DTfloat) in - vartlookup env sv = Some oval -> - v <> sv -> - subvar sv (vart_update env v (addvar v env grad)) oval = 0%R. - Proof. - intros. - unfold subvar; simpl. - rewrite lookup_update_neq; trivial. - rewrite H. - lra. - Qed. - - Lemma subvar_addvar_scalar_eq (env : df_env) (s : SubVar) (oval grad : float) : - let v := (s, DTfloat) in - vartlookup env v = Some oval -> - subvar v (vart_update env v (addvar v env grad)) oval = grad. - Proof. - intros. - unfold subvar; simpl. - rewrite lookup_update. - unfold addvar; simpl. - rewrite H. - lra. - Qed. - - Lemma split_subvar (env1 env2: df_env) (oval val1 : float) (s : SubVar) : - let v := (s, DTfloat) in - vartlookup env1 v = Some val1 -> - subvar v env2 oval = (subvar v env2 val1 + subvar v env1 oval)%R. - Proof. - intros. - unfold subvar; simpl. - rewrite H. - case_eq (vartlookup env2 v); intros. - lra. - lra. - Qed. - - Lemma vsum_nil : vsum vnil = 0%R. - Proof. - reflexivity. - Qed. - - Lemma vsum_mult {n} (v : Vector float n) (c : float) : - (c * vsum v)%R = vsum (fun j => (c * v j)%R). - Proof. - unfold vsum, vector_fold_right1, Datatypes.id; simpl. - induction n; [ | destruct n]. - - repeat rewrite vector_fold_right1_dep_0; lra. - - repeat rewrite vector_fold_right1_dep_1; lra. - - repeat rewrite vector_fold_right1_dep_SSn. - rewrite Rmult_plus_distr_l. - specialize (IHn (vdrop_last v)); simpl in IHn. - rewrite IHn. - f_equal. - apply vector_fold_right1_dep_ext. - intros [i pf]; trivial. - Qed. - - Lemma vsum_plus {m:nat} (v1 v2:Vector R m) : - (vsum v1 + vsum v2)%R = vsum (fun i => (v1 i + v2 i)%R). - Proof. - unfold vsum, vector_fold_right1, Datatypes.id; simpl. - induction m; [ | destruct m]. - - repeat rewrite vector_fold_right1_dep_0; lra. - - repeat rewrite vector_fold_right1_dep_1; lra. - - repeat rewrite vector_fold_right1_dep_SSn. - specialize (IHm (vdrop_last v1) (vdrop_last v2)); simpl in IHm. - rewrite (Rplus_comm (vlast v2)). - rewrite (Rplus_assoc (vlast v1)). - rewrite <- (Rplus_assoc _ _ (vlast v2)). - rewrite IHm. - rewrite (Rplus_comm _ (vlast v2)). - rewrite <- Rplus_assoc. - f_equal. - apply vector_fold_right1_dep_ext. - intros [i pf]; trivial. - Qed. - - Lemma vmap_mult {n} (f: float -> float) (v : Vector float n) (c : float) : - forall i : {n' : nat | n' < n}, - (c * (vmap f v) i)%R = (vmap (fun x => (c * f x)%R) v) i. - Proof. - intros. - rewrite vmap_nth. - now rewrite vmap_nth. - Qed. - - Lemma vsum_ext {n} (v v':Vector float n) : vec_eq v v' -> vsum v = vsum v'. - Proof. - apply vector_fold_right1_ext. - Qed. - - Lemma msum_ext {m n} (mat mat':Matrix float m n) : - (forall i j, mat i j = mat' i j) -> msum mat = msum mat'. - Proof. - intros. - apply vsum_ext; intros ?. - repeat rewrite vmap_nth. - apply vsum_ext; intros ?; auto. - Qed. - - Lemma msum_mult {m n} (mat : Matrix float m n) (c : float) : - (c * msum mat)%R = msum (fun i j => (c * mat i j)%R). - Proof. - unfold msum. - rewrite vsum_mult. - apply vsum_ext; intros i. - repeat rewrite vmap_nth. - now rewrite vsum_mult. - Qed. - - Lemma msum_mmap_mult {m n} (mat : Matrix float m n) (c : float) : - (c * msum mat)%R = msum (mmap (fun x => c * x)%R mat). - Proof. - rewrite msum_mult. - apply msum_ext; intros i j. - now rewrite mmap_nth. - Qed. - - Lemma msum_mmap_div_denom {m n} (mat : Matrix float m n) (c : float) : - msum (mmap (fun u : R => (u / c)%R) mat) = (msum mat / c)%R. - Proof. - transitivity (msum (mmap (fun u : R => (/ c * u)%R) mat)). - - apply msum_ext; intros i j. - repeat rewrite mmap_nth. - lra. - - rewrite <- msum_mmap_mult. - lra. - Qed. - - Lemma vsum0 n : vsum (fun _ : {n' : nat | (n' < n)%nat} => 0%R) = 0%R. - Proof. - generalize (vsum_mult (fun _ : {n' : nat | (n' < n)%nat} => 0%R) 0%R); intros HH. - rewrite Rmult_0_l in HH. - symmetry. - simpl in *. - erewrite vsum_ext; [eassumption | ]. - intro; simpl; lra. - Qed. - - Lemma vsum_unitvector {n} (v:Vector R n) i : - vsum (fun j => (v j * UnitVector n i j)%R) = v i. - Proof. - unfold vsum, vector_fold_right1, Datatypes.id, UnitVector; simpl. - revert n v i. - destruct i. - induction n; [ | destruct n]. - - lia. - - repeat rewrite vector_fold_right1_dep_1. - destruct x; [ | lia]; simpl. - field_simplify. - now erewrite index_pf_irrel. - - repeat rewrite vector_fold_right1_dep_SSn. - unfold vlast, vdrop_last; simpl. - destruct (equiv_dec (S n) x). - + ring_simplify. - simpl. - destruct e. - match goal with - | [|- (_ + ?x)%R = _ ] => replace x with 0%R - end. - * ring_simplify. - now erewrite index_pf_irrel. - * rewrite <- (vsum0 (S n)) at 1. - unfold vsum, vector_fold_right1, Fzero, Datatypes.id; simpl. - apply (@vector_fold_right1_dep_ext (fun _ => R)). - intros [??]. - destruct (equiv_dec x (S n)). - -- destruct e. - lia. - -- lra. - + ring_simplify. - unfold equiv, complement in c. - assert (pf:x < S n) by lia. - specialize (IHn (vdrop_last v) pf). - simpl in IHn. - erewrite index_pf_irrel. - rewrite <- IHn. - apply (@vector_fold_right1_dep_ext (fun _ => R)). - now intros [??]. - Qed. - - Lemma msum_unitmatrix {m n} (v:Matrix R m n) i j : - msum (fun k l => (v k l * UnitMatrix m n i j k l)%R) = v i j. - Proof. - unfold msum. - unfold UnitMatrix. - rewrite (vsum_ext _ ( - (fun (k : {n' : nat | n' < m}) => @vsum floatish_R _ - (fun (l : {m' : nat | m' < n}) => - (v k l * - (if equiv_dec (` k) (` i) then if equiv_dec (` l) (` j) then 1%R else 0%R else 0%R))%R)) - - )) - by (intros ?; now rewrite vmap_nth). - rewrite (vsum_ext _ ( - (fun (k : {n' : nat | n' < m}) => (if equiv_dec (` k) (` i) then - @vsum floatish_R _ - (fun (l : {m' : nat | m' < n}%nat) => - ((v k) l * - if equiv_dec (` l) (` j) then 1%R else 0%R))%R else 0%R)) - - )). - - rewrite (vsum_ext _ ( - (fun (k : {n' : nat | n' < m}) => (if equiv_dec (` k) (` i) then - v k j else 0%R)) - - )). - + rewrite (vsum_ext _ ( - (fun (k : {n' : nat | n' < m}) => ((transpose v) j k * @UnitVector floatish_R m i k)%R) - - )). - * now rewrite vsum_unitvector. - * unfold UnitVector; intros ?; simpl. - dest_eqdec; unfold transpose; simpl - ; lra. - + intros ?. - dest_eqdec; trivial. - apply vsum_unitvector. - - intros ?. - dest_eqdec; trivial. - rewrite <- (vsum0 n) at 1. - apply vsum_ext. - intros ?; lra. - Qed. - - Ltac vectoro_assert_forall_in H i - := match type of H with vectoro_to_ovector ?x = Some ?y => - assert (forall i, x i = Some (y i)) end. - - Lemma vartlookup_list_env_iter {A} - (s: SubVar) - (f : A -> df_env -> option df_env) - (l : list A) (env fenv: df_env): - list_env_iter f (Some env) l = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - (forall (env fenv: df_env) (i:A), - f i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - vartlookup fenv (s, DTfloat) <> None. - Proof. - intros. - revert H0 H. - generalize env. - induction l. - - simpl; intros. - now invcs H. - - simpl; intros. - generalize (list_env_iter_none f l); intros. - assert (f a env0 <> None); [congruence | ]. - case_eq (f a env0); [|congruence]. - intros. - apply (IHl d). - + specialize (H1 env0 d a). - now apply H1. - + now rewrite H4 in H. - Qed. - -(* Theorem df_eval_deriv_same {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s:SubVar) : - let v := (s, DTfloat) in - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - df_eval_deriv σ df v = df_eval σ (df_deriv df v). - Proof. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl in *;trivial; intros - ; try (destruct H; rewrite IHdf1; trivial; rewrite IHdf2; trivial) - ; try (rewrite IHdf; trivial; do 2 match_option). - - Case "DVector"%string. - f_equal. - apply functional_extensionality; intros. - rewrite vforall_forall in H0. - specialize (H x0); simpl in H. - specialize (H0 x0). - apply H; trivial. - - Case "DMatrix"%string. - f_equal. - apply functional_extensionality; intros. - apply functional_extensionality; intros. - rewrite vforall_forall in H0. - specialize (H x0); simpl in H. - specialize (H0 x0). - rewrite vforall_forall in H0. - apply H; trivial. - - case "Times"%string. - intros; do 4 match_option. - - Case "Divide"%string. - intros; do 4 match_option. - - Case "Sign"%string. - match_option. - generalize (eval_deriv_fully_closed_not_none σ df (s, DTfloat)); tauto. - - Case "PSign"%string. - match_option. - generalize (eval_deriv_fully_closed_not_none σ df (s, DTfloat)); tauto. - - Case "Max"%string. - assert (df_eval σ df1 <> None) by (apply eval_fully_closed_not_none; trivial). - assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - match_option; [|congruence]. - assert (df_eval σ (df_deriv df1 (s, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - assert (df_eval σ (df_deriv df2 (s, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - destruct (df_eval σ (df_deriv df1 (s, DTfloat))); [|congruence]. - destruct (df_eval σ (df_deriv df2 (s, DTfloat))); [|congruence]. - unfold pos_sign; simpl. - case_eq (Rle_dec d d0); intros; f_equal. - destruct (Rge_dec (d0 - d) 0); lra. - destruct (Rge_dec (d0 - d) 0); lra. - - Case "VectorDot"%string. - do 4 match_option. - rewrite vsum_plus. - do 2 f_equal. - f_equal. - apply functional_extensionality; intros. - lra. - - Case "MatrixVectorMult"%string. - do 4 match_option. - f_equal. - apply functional_extensionality; intros. - unfold matrix_vector_mult. - rewrite vsum_plus; simpl. - f_equal. - apply functional_extensionality; intros. - lra. - - Case "MatrixMult"%string. - do 4 match_option. - f_equal. - apply functional_extensionality; intros. - apply functional_extensionality; intros. - unfold matrix_mult. - rewrite vsum_plus; simpl; f_equal. - apply functional_extensionality; intros. - lra. - - Case "VectorScalMult"%string. - do 4 match_option. - f_equal. - apply functional_extensionality; intros. - lra. - - Case "MatrixScalMult"%string. - do 4 match_option. - f_equal. - apply functional_extensionality; intros. - apply functional_extensionality; intros. - lra. - - Case "VectorApply"%string. - destruct H. - assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - rewrite IHdf2; trivial. - match_option. - + f_equal. - apply functional_extensionality; intros. - assert ( df_eval_deriv [mk_env_entry (v, DTfloat) (d x)] df1 (v, DTfloat) = - df_eval σ (df_subst (df_deriv df1 (v, DTfloat)) (v, DTfloat) - (VectorElem () df2 x))). - XXX - now rewrite H2. - + assert ( df_eval σ (df_deriv df2 (s, DTfloat)) <> None ); [|tauto]. - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - - Case "MatrixApply"%string. - destruct H. - assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - rewrite IHdf2; trivial. - match_option. - + f_equal. - apply functional_extensionality; intros. - apply functional_extensionality; intros. - assert ( df_eval_deriv [mk_env_entry (v, DTfloat) (d x x0)] df1 (v, DTfloat) = - df_eval σ (df_subst (df_deriv df1 (v, DTfloat)) (v, DTfloat) - (MatrixElem () df2 x x0))). - XXX - now rewrite H2. - + assert ( df_eval σ (df_deriv df2 (s, DTfloat)) <> None ); [|tauto]. - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - - Case "VLossfun"%string. - destruct H. - assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - rewrite IHdf2; trivial. - match_option. - do 2 match_option. - + do 2 f_equal. - apply functional_extensionality; intros. - XXX - + generalize (vectoro_to_ovector_exists_None eqq2). - intros; destruct H2. - assert (df_eval - σ - (df_subst (df_subst (df_deriv df1 (v1, DTfloat)) (v1, DTfloat) (VectorElem () df2 x)) - (v2, DTfloat) (Number () (r x))) <> None); [|tauto]. - apply eval_fully_closed_not_none. - apply fully_closed_subst; simpl; [|trivial]. - apply fully_closed_subst; simpl. - * apply fully_closed_deriv. - now apply fully_closed_over_pair. - * now apply fully_closed_over_cons. - + generalize (vectoro_to_ovector_exists_None eqq1). - intros; destruct H2. - match_option_in e. - assert (df_eval_deriv [mk_env_entry (v1, DTfloat) (d x); mk_env_entry (v2, DTfloat) (r x)] - df1 (v1, DTfloat) <> None); [|tauto]. - apply eval_deriv_fully_closed_not_none; trivial. - - Case "MLossfun"%string. - destruct H. - assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - rewrite IHdf2; trivial. - match_option. - do 2 match_option. - + do 2 f_equal. - XXX - + generalize (vectoro_to_ovector_exists_None eqq2); intros; destruct H2. - generalize (vectoro_to_ovector_exists_None e); intros; destruct H2. - match_option_in e0. - match_option_in eqq3. - assert (df_eval σ - (df_subst - (df_subst (df_deriv df1 (v1, DTfloat)) (v1, DTfloat) (MatrixElem () df2 x x0)) - (v2, DTfloat) (Number () (r x x0))) <> None); [|tauto]. - apply eval_fully_closed_not_none. - apply fully_closed_subst; simpl; [|trivial]. - apply fully_closed_subst; simpl. - * apply fully_closed_deriv. - now apply fully_closed_over_pair. - * now apply fully_closed_over_cons. - + generalize (vectoro_to_ovector_exists_None eqq1). - intros; destruct H2. - generalize (vectoro_to_ovector_exists_None e). - intros; destruct H2. - match_option_in e0. - assert (df_eval_deriv [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] - df1 (v1, DTfloat) <> None); [|tauto]. - apply eval_deriv_fully_closed_not_none; trivial. - - - *) - - Theorem df_eval_deriv_scalar_same (σ:df_env) (df:DefinedFunction UnitAnn DTfloat) (s:SubVar) : - let v := (s, DTfloat) in - let vl := map (fun ve => projT1 ve) σ in - is_scalar_function df -> - fully_closed_over df vl -> - df_eval_deriv σ df v = df_eval σ (df_deriv df v). - Proof. - simpl. - intros is_scalar. - generalize is_scalar. - pattern df. - revert df is_scalar. - DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case - ; simpl; trivial; intros - ; try (destruct H1; destruct is_scalar; - specialize (H H3 H1); specialize (H0 H4 H2); now rewrite H, H0) - ; try (rewrite H; trivial; do 2 match_option). - - Case "Times"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - rewrite H, H0. - do 4 match_option. - - Case "Divide"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - rewrite H, H0. - do 4 match_option. - - Case "Sign"%string. - match_option. - assert ( df_eval_deriv σ e (s, DTfloat) <> None); [|tauto]. - apply eval_deriv_fully_closed_not_none; trivial. - - Case "PSign"%string. - match_option. - assert ( df_eval_deriv σ e (s, DTfloat) <> None); [|tauto]. - apply eval_deriv_fully_closed_not_none; trivial. - - Case "Max"%string. - destruct is_scalar; destruct H1. - cut_to H; trivial. - cut_to H0; trivial. - rewrite H, H0. - assert (df_eval σ l <> None) by (apply eval_fully_closed_not_none; trivial). - assert (df_eval σ r <> None) by (apply eval_fully_closed_not_none; trivial). - match_option; [|congruence]. - match_option; [|congruence]. - assert (df_eval σ (df_deriv l (s, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - assert (df_eval σ (df_deriv r (s, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - destruct (df_eval σ (df_deriv l (s, DTfloat))); [|congruence]. - destruct (df_eval σ (df_deriv r (s, DTfloat))); [|congruence]. - unfold pos_sign; simpl. - case_eq (Rle_dec d d0); intros; f_equal. - destruct (Rge_dec (d0 - d) 0); lra. - destruct (Rge_dec (d0 - d) 0); lra. - Qed. - - Lemma vartlookup_list_env_iter2 {A} - (s: SubVar) - {f : A -> df_env -> option df_env} - {l : list A} {env fenv: df_env}: - list_env_iter f (Some env) l = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - (forall (env fenv: df_env) (i:A), - f i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - vartlookup fenv (s, DTfloat) <> None. - Proof. - apply (vartlookup_list_env_iter s f l env fenv). - Qed. - - Lemma scalarMult_list_env_iter - (s: SubVar) (c val1 val2:float) (A :Type) - (f g : A -> df_env -> option df_env) - (l : list A) (env1 env2 fenv1 fenv2: df_env): - list_env_iter f (Some env1) l = Some fenv1 -> - list_env_iter g (Some env2) l = Some fenv2 -> - vartlookup env1 (s, DTfloat) = Some val1 -> - vartlookup env2 (s, DTfloat) = Some val2 -> - (forall (i:A) (env1 env2 fenv1 fenv2: df_env) (v1 v2: float), - vartlookup env1 (s, DTfloat) = Some v1 -> - vartlookup env2 (s, DTfloat) = Some v2 -> - f i env1 = Some fenv1 -> g i env2 = Some fenv2 -> - subvar (s, DTfloat) fenv1 v1 = (c * subvar (s, DTfloat) fenv2 v2)%R) -> - (forall (env fenv: df_env) (i:A), - f i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - (forall (env fenv: df_env) (i:A), - g i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - subvar (s, DTfloat) fenv1 val1 = (c * subvar (s, DTfloat) fenv2 val2)%R. - Proof. - intros. - generalize (vartlookup_list_env_iter s f l env1 fenv1); intros. - assert (vartlookup env1 (s, DTfloat) <> None). - rewrite H1; discriminate. - assert (vartlookup env2 (s, DTfloat) <> None). - rewrite H2; discriminate. - specialize (H6 H H7 H4). - generalize (vartlookup_list_env_iter s g l env2 fenv2); intros. - specialize (H9 H0 H8 H5). - revert H1 H2 H H0. - generalize env1 env2 val1 val2. - induction l. - - intros. - unfold subvar; simpl. - unfold list_env_iter in H; simpl in H. - unfold list_env_iter in H0; simpl in H0. - invcs H; invcs H0. - rewrite H1; rewrite H2; lra. - - simpl; intros. - generalize (list_env_iter_none f l); intros. - assert (f a env0 <> None); [congruence | ]. - case_eq (f a env0); [intros|congruence]. - generalize (list_env_iter_none g l); intros. - assert (g a env3 <> None); [congruence | ]. - case_eq (g a env3); [intros|congruence]. - assert (vartlookup d (s, DTfloat) <> None). - apply (H4 env0 d a); trivial; congruence. - assert (vartlookup d0 (s, DTfloat) <> None). - apply (H5 env3 d0 a); trivial; congruence. - case_eq (vartlookup d (s, DTfloat)); [intros | tauto]. - case_eq (vartlookup d0 (s, DTfloat)); [intros | tauto]. - specialize (IHl d d0 d1 d2). - specialize (H3 a env0 env3 d d0 val0 val3). - rewrite (split_subvar d fenv1 val0 d1); trivial. - rewrite (split_subvar d0 fenv2 val3 d2); trivial. - specialize (H3 H1 H2 H12 H15). - rewrite H12 in H. - rewrite H15 in H0. - specialize (IHl H18 H19 H H0). - lra. - Qed. - - Lemma list_env_iter_subvar_env2 - (s: SubVar) (val1 val2:float) (A :Type) - (f g : A -> df_env -> option df_env) - (l : list A) (env1 env2 fenv1 fenv2: df_env): - list_env_iter f (Some env1) l = Some fenv1 -> - list_env_iter g (Some env2) l = Some fenv2 -> - vartlookup env1 (s, DTfloat) = Some val1 -> - vartlookup env2 (s, DTfloat) = Some val2 -> - (forall (i:A) (env1 env2 fenv1 fenv2: df_env) (v1 v2: float), - vartlookup env1 (s, DTfloat) = Some v1 -> - vartlookup env2 (s, DTfloat) = Some v2 -> - f i env1 = Some fenv1 -> g i env2 = Some fenv2 -> - subvar (s, DTfloat) fenv1 v1 = subvar (s, DTfloat) fenv2 v2) -> - (forall (env fenv: df_env) (i:A), - f i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - (forall (env fenv: df_env) (i:A), - g i env = Some fenv -> - vartlookup env (s, DTfloat) <> None -> - vartlookup fenv (s, DTfloat) <> None) -> - subvar (s, DTfloat) fenv1 val1 = subvar (s, DTfloat) fenv2 val2. - Proof. - intros. - generalize (scalarMult_list_env_iter s 1%R val1 val2 A f g l env1 env2 fenv1 fenv2). - intros. - specialize (H6 H H0 H1 H2). - cut_to H6. - now replace (1 * subvar (s, DTfloat) fenv2 val2)%R with (subvar (s, DTfloat) fenv2 val2) in H6 by lra. - intros. - replace (1 * subvar (s, DTfloat) fenv3 v2)%R with (subvar (s, DTfloat) fenv3 v2) by lra. - apply (H3 i env0 env3 fenv0 fenv3); trivial. - apply H4. - apply H5. - Qed. - - Lemma scalarMult_backprop_grad_scalar {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (s: SubVar) (grad_env1 grad_env2:df_env) (grad : definition_function_types_interp T) (c:float) : - let v := (s, DTfloat) in - vartlookup grad_env1 v <> None -> vartlookup grad_env2 v <> None -> - df_eval_backprop_deriv σ df grad_env1 (scalarMult T c grad) <> None -> - df_eval_backprop_deriv σ df grad_env2 grad <> None -> - df_eval_backprop_delta σ df v grad_env1 (scalarMult T c grad) = - lift (fun e => scalarMult (snd v) c e) (df_eval_backprop_delta σ df v grad_env2 grad). - Proof. - revert grad_env1 grad_env2. - unfold df_eval_backprop_delta. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case - ; simpl; intros grad_env1 grad_env2 neq1 neq2; intros. - - Case "Number"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [|tauto]. - intros; simpl; f_equal. - unfold subvar; simpl. - match_destr; match_destr. - inversion H1; subst. - inversion H2; subst. - lra. - - Case "Constant"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - intros; simpl; f_equal. - unfold subvar; simpl. - match_destr; match_destr. - inversion H1; subst. - inversion H2; subst. - lra. - - Case "DVector"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - intros; simpl. - unfold lift. - match_option; [|tauto]. - case_eq (two_vector_env_iter_alt - (fun (x0 : DefinedFunction Ann DTfloat) (g : R) (genv : df_env) => - df_eval_backprop_deriv σ x0 genv g) grad_env2 x grad); [|tauto]. - unfold two_vector_env_iter_alt in *. - intros; f_equal. - apply (scalarMult_list_env_iter - s c d0 d {n' : nat | n' < n} - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (c * grad i)%R) - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad i)) - (bounded_seq0 n) grad_env1 grad_env2); trivial. - + intros. - specialize (H i (grad i) env1 env2). - assert (vartlookup env1 (s, DTfloat) <> None); [congruence|]. - assert (vartlookup env2 (s, DTfloat) <> None); [congruence|]. - specialize (H H9 H10). - assert (df_eval_backprop_deriv σ (x i) env1 (c * grad i)%R <> None); [congruence|]. - assert (df_eval_backprop_deriv σ (x i) env2 (grad i) <> None); [congruence|]. - specialize (H H11 H12). - unfold lift in H; simpl in H. - rewrite H5, H6, H7, H8 in H. - now inversion H. - + intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H5). - + intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H5). - - Case "DMatrix"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - intros; simpl. - unfold lift. - match_option; [|tauto]. - case_eq (two_matrix_env_iter_alt - (fun (x0 : DefinedFunction Ann DTfloat) (g : R) (genv : df_env) => - df_eval_backprop_deriv σ x0 genv g) grad_env2 x grad); [|tauto]. - intros; f_equal. - unfold two_matrix_env_iter_alt in *. - apply (scalarMult_list_env_iter - s c d0 d {n' : nat | n' < n} - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R) - (Some env) (bounded_seq0 m)) - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) (Some env) - (bounded_seq0 m)) - (bounded_seq0 n) grad_env1 grad_env2); trivial. - + intros. - apply (scalarMult_list_env_iter - s c v1 v2 {m' : nat | m' < m} - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R) - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (bounded_seq0 m) env1 env2); trivial. - * intros. - specialize (H i i0 (grad i i0) env0 env3). - assert (vartlookup env0 (s, DTfloat) <> None); [congruence|]. - assert (vartlookup env3 (s, DTfloat) <> None); [congruence|]. - specialize (H H13 H14). - assert (df_eval_backprop_deriv σ (x i i0) env0 (c * grad i i0)%R <> None); [congruence|]. - assert (df_eval_backprop_deriv σ (x i i0) env3 (grad i i0) <> None); [congruence|]. - specialize (H H15 H16). - unfold lift in H; simpl in H. - rewrite H9, H10, H11, H12 in H. - now inversion H. - * intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H9). - * intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H9). - + intros. - apply (vartlookup_list_env_iter - s - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R) - (bounded_seq0 m) env fenv); trivial; intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H7). - + intros. - apply (vartlookup_list_env_iter - s - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (bounded_seq0 m) env fenv); trivial; intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H7). - - Case "Var"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - intros. - destruct (vart_dec v (s, DTfloat)). - + subst; simpl. - rewrite H2; simpl. - rewrite subvar_addvar_scalar_eq; trivial. - rewrite H1; simpl. - now rewrite subvar_addvar_scalar_eq. - + case_eq (vartlookup grad_env1 v); intros; simpl. - * case_eq (vartlookup grad_env2 v); intros; simpl; f_equal. - -- rewrite subvar_addvar_scalar_neq; trivial. - rewrite subvar_addvar_scalar_neq; trivial. - lra. - -- rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl. - rewrite H1. - lra. - * case_eq (vartlookup grad_env2 v); intros; simpl; f_equal. - -- rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl. - rewrite H2. - lra. - -- unfold subvar; simpl. - rewrite H2; rewrite H1. - lra. - - Case "Plus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad)%R); intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 grad d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (c * grad)%R); [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "Minus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad)%R); intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 (-grad)%R d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (-(c * grad))%R); [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 (-grad)%R); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - replace (c * -grad)%R with (-(c*grad))%R in IHdf2 by lra. - rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "Times"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (d * grad)%R grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (d * (c * grad))%R); intros. - rewrite H1, H2, H5 in H; simpl in H. - rewrite H1, H2 in H0; simpl in H0. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 (d * grad)%R); intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (d0 * grad)%R d3 d4). - replace (c * (d * grad))%R with (d * (c * grad))%R in IHdf1 by lra. - rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d3 (d0 * (c * grad))%R); [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (d0 * grad)%R); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (c * (d0 * grad))%R with (d0 * (c*grad))%R in IHdf2 by lra. - rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "Divide"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (grad / d)%R grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad / d)%R); intros. - rewrite H1 in H; rewrite H2 in H; simpl in H. - rewrite H1 in H0; rewrite H2 in H0; simpl in H0. - rewrite H5 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 (grad / d)%R); intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (- d0 / (d*d) * grad)%R d3 d4). - replace (c * (grad / d))%R with (c * grad / d )%R in IHdf1 by lra. - rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d3 (- d0 /(d * d) * (c * grad))%R); [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (- d0 / (d * d) * grad)%R); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (c * (-d0 / (d * d) * grad))%R with (- d0/(d * d) * (c*grad))%R in IHdf2 by lra. - rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "Square"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df); [ | tauto]; intros. - specialize (IHdf (2 * d1 * grad)%R grad_env1 grad_env2); simpl in *. - rewrite H1 in IHdf; rewrite H2 in IHdf. - replace (2 * d1 * (c * grad))%R with (c * (2 * d1 * grad))%R by lra. - rewrite H3 in H; rewrite H3 in H0; simpl in *. - replace (2 * d1 * (c * grad))%R with (c * (2 * d1 * grad))%R in H by lra. - now apply IHdf. - - Case "Exp"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df); [ | tauto]; intros. - specialize (IHdf (grad * exp d1)%R grad_env1 grad_env2); simpl in *. - replace (c * grad * exp d1)%R with (c * (grad * exp d1))%R by lra. - rewrite H1 in IHdf; rewrite H2 in IHdf. - rewrite H3 in H; rewrite H3 in H0. - replace (c * grad * exp d1)%R with (c * (grad * exp d1))%R in H by lra. - now apply IHdf. - - Case "Log"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df); [ | tauto]; intros. - specialize (IHdf (grad / d1)%R grad_env1 grad_env2 ); simpl in *. - replace (c * grad / d1)%R with (c * (grad / d1))%R by lra. - rewrite H1 in IHdf; rewrite H2 in IHdf. - rewrite H3 in H; rewrite H3 in H0. - replace (c * grad / d1)%R with (c * (grad / d1))%R in H by lra. - now apply IHdf. - - Case "Abs"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df); [ | tauto]; intros. - specialize (IHdf (grad * sign d1)%R grad_env1 grad_env2); simpl in *. - replace (c * grad * (@sign floatish_R d1))%R - with (c * (grad * (@sign floatish_R d1)))%R by lra. - rewrite H1 in IHdf; rewrite H2 in IHdf. - rewrite H3 in H; rewrite H3 in H0. - replace (c * grad * (@sign floatish_R d1))%R - with (c * (grad * (@sign floatish_R d1)))%R in H by lra. - now apply IHdf. - - Case "Sign"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (0)%R grad_env1 grad_env2); simpl in *. - replace (0%R) with (c * 0)%R at 1 by lra. - rewrite H1 in IHdf; rewrite H2 in IHdf. - replace (0%R) with (c * 0)%R in H by lra. - now apply IHdf. - - Case "PSign"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (0)%R grad_env1 grad_env2); simpl in *. - replace (0%R) with (c * 0)%R at 1 by lra. - rewrite H1 in IHdf; rewrite H2 in IHdf. - replace (0%R) with (c * 0)%R in H by lra. - now apply IHdf. - - Case "Max"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df1); [ | tauto]; intros. - case_eq (df_eval σ df2); [ | tauto]; intros. - rewrite H3 in H; rewrite H3 in H0. - rewrite H4 in H; rewrite H4 in H0. - case_eq (Rle_dec d1 d2); intros. - + specialize (IHdf2 grad grad_env1 grad_env2); simpl in *. - rewrite H1 in IHdf2; rewrite H2 in IHdf2. - rewrite H5 in H; rewrite H5 in H0. - now apply IHdf2. - + specialize (IHdf1 grad grad_env1 grad_env2); simpl in *. - rewrite H1 in IHdf1; rewrite H2 in IHdf1. - rewrite H5 in H; rewrite H5 in H0. - now apply IHdf1. - - Case "VectorDot"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - case_eq (df_eval σ df1); [ | tauto]; intros. - case_eq (df_eval σ df2); [ | tauto]; intros. - specialize (IHdf1 (vmap (fun rv => (rv * grad)%R) d2) grad_env1 grad_env2). - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 - (vmap (fun rv : R => (rv * (c * grad))%R) d2)); intros. - rewrite H3 in H; rewrite H4 in H; rewrite H5 in H; simpl in H. - rewrite H3 in H0; rewrite H4 in H0; simpl in H0. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 - (vmap (fun rv : R => (rv * grad)%R) d2)); intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (vmap (fun lv => (lv *grad)%R) d1) d3 d4). - replace (fun i => (c * vmap (fun rv : R => rv * grad) d2 i)%R) with - (vmap (fun rv : R => (rv * (c * grad))%R) d2) in IHdf1. - rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d3 - (vmap (fun lv : R => (lv * (c * grad))%R) d1)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (vmap (fun lv : R => (lv * grad)%R) d1)) - ; [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d d5) by trivial. - rewrite (split_subvar d4 d8 d0 d6) by trivial. - replace - (fun i : {n' : nat | n' < n} => (c * vmap (fun lv : R => lv * grad) d1 i)%R) with - (vmap (fun lv : R => (lv * (c * grad))%R) d1) in IHdf2. - rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d) = Some (c * subvar (s, DTfloat) d4 d0)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_mult. - assert ((fun lv => (lv * (c * grad))%R) = (fun x0 => (c * (x0 * grad))%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - now rewrite H13. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_mult. - assert ((fun rv => (rv * (c * grad))%R) = (fun x0 => (c * (x0 * grad))%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - now rewrite H7. - + now rewrite H6 in H0. - + rewrite H3 in H; rewrite H4 in H. - now rewrite H5 in H. - - Case "VectorSum"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (ConstVector n grad) grad_env1 grad_env2). - rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in IHdf. - replace (ConstVector n (c * grad)%R) with - (fun i => (c * ConstVector n grad i)%R). - now apply IHdf. - now unfold ConstVector. - - Case "MatrixSum"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (ConstMatrix m n grad) grad_env1 grad_env2). - rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in IHdf. - replace (ConstMatrix m n (c * grad)%R) with - (fun i j => (c * ConstMatrix m n grad i j)%R). - now apply IHdf. - now unfold ConstMatrix. - - Case "VectorElem"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (fun k => - if equiv_dec (` k) (` i) then grad else 0%R) grad_env1 grad_env2). - rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in *. - replace (fun i0 => (c * (if equiv_dec (` i0) (` i) then grad else 0))%R) with - (fun k : {n' : nat | n' < n} => - if equiv_dec (` k) (` i) then (c * grad)%R else 0%R) in IHdf. - now rewrite IHdf. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)); lra. - - Case "MatrixElem"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - specialize (IHdf (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) - then if equiv_dec (` k2) (` j) then grad else 0%R else 0%R) - grad_env1 grad_env2). - rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in *. - replace (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) - then if equiv_dec (` k2) (` j) then (c * grad)%R else 0%R else 0%R) with - (fun (i0 : {n' : nat | n' < m}) (j0 : {m' : nat | m' < n}) => - (c * - (if equiv_dec (` i0) (` i) - then if equiv_dec (` j0) (` j) then grad else 0 - else 0))%R) in *. - + now rewrite IHdf. - + apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)); [|lra]. - destruct (equiv_dec (` x0) (` j)); lra. - - Case "MatrixVectorMult"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (fun i j => (grad i * d j)%R) grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 - (fun i j => (c * grad i * d j)%R)); intros. - rewrite H1 in H; rewrite H2 in H; simpl in H. - rewrite H1 in H0; rewrite H2 in H0; simpl in H0. - rewrite H5 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 - (fun i j => (grad i * d j)%R)); intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (matrix_vector_mult (fun i j => d0 j i) grad) d3 d4). - replace - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (c * (grad i * d j))%R) with - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (c * grad i * d j)%R) in IHdf1. - * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - unfold lift; match_case; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 - (matrix_vector_mult (fun i j => d0 j i) grad)) - ; [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (fun i : {n' : nat | n' < n} => - (c * - (@matrix_vector_mult floatish_R _ _ - (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d0 j i0) grad) i)%R) with - (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => d0 j i) - (fun i : {n' : nat | n' < m} => (c * grad i)%R)) in IHdf2. - -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = - Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = - Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - -- unfold matrix_vector_mult. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vsum_mult; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "MatrixVectorAdd"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R)) - ; intros. - + rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - * rewrite H4 in H0; simpl in H0. - match_option_in H0; [|tauto]. - match_option_in H; [|tauto]. - rewrite H3 in IHdf1. - unfold lift. - f_equal. - rewrite H4 in IHdf1. - unfold lift in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite (split_subvar d1 d4 d0 d5) by trivial. - rewrite (split_subvar d2 d3 d d6) by trivial. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) - by (apply IHdf1; trivial; discriminate). - inversion H9; rewrite H11. - assert (Some (subvar (s, DTfloat) d4 d5) = Some (c * subvar (s, DTfloat) d3 d6)%R). - -- f_equal. - apply (scalarMult_list_env_iter - s c d5 d6 {m' : nat | m' < n} - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ df2 env - (transpose - (fun (i0 : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => - (c * grad i0 j)%R) i)) - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) - (bounded_seq0 n) d1 d2 d4 d3); trivial. - ++ intros. - assert (vartlookup env1 (s, DTfloat) <> None); [congruence|]. - assert (vartlookup env2 (s, DTfloat) <> None); [congruence|]. - assert (df_eval_backprop_deriv - σ df2 env1 - (transpose - (fun (i0 : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (c * grad i0 j)%R) i) <> None) - ; [congruence|]. - assert (df_eval_backprop_deriv σ df2 env2 (transpose grad i) <> None) - ;[congruence|]. - specialize (IHdf2 (transpose grad i) env1 env2). - specialize (IHdf2 H15 H16). - specialize (IHdf2 H17 H18). - unfold lift in IHdf2; simpl in IHdf2. - rewrite H10, H12, H14 in IHdf2. - unfold transpose in IHdf2; unfold transpose in H13. - rewrite H13 in IHdf2; now inversion IHdf2. - ++ intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H10). - ++ intros. - now apply (df_eval_backprop_deriv_preserves_lookup_not_none H10). - -- inversion H10; rewrite H13; lra. - * rewrite H4 in H0; tauto. - + rewrite H3 in H; tauto. - - Case "MatrixMult"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (matrix_mult grad (fun i j => d j i)) grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 - (matrix_mult (fun i j => (c * grad i j)%R) - (fun i j => d j i))); intros. - rewrite H1 in H; rewrite H2 in H; simpl in H. - rewrite H1 in H0; rewrite H2 in H0; simpl in H0. - rewrite H5 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 - (matrix_mult grad (fun i j => d j i))); intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (matrix_mult (fun i j => d0 j i) grad) d3 d4). - replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < p}) => - (c * - (@matrix_mult floatish_R m n p grad - (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) => - d j0 i0)) i j)%R) with - (matrix_mult (fun i j => (c * grad i j)%R) - (fun i j => d j i)) in IHdf1. - * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - unfold lift; match_case; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (matrix_mult (fun i j => d0 j i) grad)) - ; [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < n}) => - (c * - (@matrix_mult floatish_R p m n - (fun (i0 : {n' : nat | (n' < p)%nat}) - (j0 : {m' : nat | (m' < m)%nat}) => - d0 j0 i0) grad) i j)%R) with - (matrix_mult (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < m}) => d0 j i) - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (c * grad i j)%R)) in IHdf2. - -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = - Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = - Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - -- unfold matrix_mult. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vsum_mult; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - * unfold matrix_mult. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vsum_mult; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "VectorPlus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i => (c * grad i)%R)); intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 grad d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (fun i => (c * grad i)%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "VectorMinus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i => (c * grad i )%R)) - ; intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 (fun i => (- grad i)%R) d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (fun i => (- (c * grad i))%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 (fun i => (- grad i)%R)); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - replace (fun i => (c * - grad i)%R) with (fun i => (-( c * grad i))%R) in IHdf2. - * rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "MatrixPlus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R)) - ; intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 grad d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (fun i j => (c * grad i j)%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "MatrixMinus"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - specialize (IHdf1 grad grad_env1 grad_env2); intros. - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R)) - ; intros. - rewrite H3 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros. - rewrite H4 in H0; simpl in H0. - + specialize (IHdf2 (fun i j => (- grad i j)%R) d1 d2). - rewrite H3 in IHdf1; rewrite H4 in IHdf1. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1). - case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros. - rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d1 (fun i j => (- (c * grad i j))%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d2 (fun i j => (- grad i j)%R)); [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d1 d5 d0 d3) by trivial. - rewrite (split_subvar d2 d6 d d4) by trivial. - replace (fun i j => (c * - grad i j)%R) with - (fun i j => (-( c * grad i j))%R) in IHdf2. - * rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by - (apply IHdf2; trivial; discriminate). - inversion H11; inversion H12. - rewrite H14; rewrite H15; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + now rewrite H4 in H0. - + now rewrite H3 in H. - - Case "VectorScalMult"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (vsum (fun j => (d j * grad j)%R)) grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 - (vsum (fun j => (d j * (c * grad j))%R))) - - ; intros. - rewrite H1 in H; rewrite H2 in H; simpl in H. - rewrite H1 in H0; rewrite H2 in H0; simpl in H0. - rewrite H5 in H; simpl in H. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 - (vsum (fun j => (d j * grad j)%R))) - ; intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (fun j => (grad j * d0)%R) d3 d4). - replace - (c * - (@vsum floatish_R _ - (fun (j : {n' : nat | (n' < n)%nat}) => - d j * grad j)))%R with - (vsum - (fun (j : {n' : nat | n' < n}) => - (d j * (c * grad j))%R)) in IHdf1. - * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d3 - (fun (j : {n' : nat | n' < n}) => - (d0 * (c * grad j))%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (fun j => (d0 * grad j)%R)) - ; [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (fun i => (c * (grad i * d0))%R) with - (fun j => (d0 * (c * grad j))%R) in IHdf2. - replace (fun j => (grad j * d0)%R) with (fun j => (d0 * grad j)%R) in IHdf2. - -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = - Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = - Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - -- apply FunctionalExtensionality.functional_extensionality; intros. - lra. - -- apply FunctionalExtensionality.functional_extensionality; intros. - lra. - * rewrite vsum_mult; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "MatrixScalMult"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]. - case_eq ( df_eval σ df1); [|tauto]. - case_eq ( df_eval σ df2); [|tauto]; intros. - specialize (IHdf1 (msum (fun i j => (d i j * grad i j)%R)) grad_env1 grad_env2). - rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *. - case_eq (df_eval_backprop_deriv σ df1 grad_env1 - (msum (fun i j => (d i j * (c * grad i j))%R))) - - ; intros. - rewrite H1, H2, H5 in H; simpl in H. - rewrite H1, H2 in H0; simpl in H0. - case_eq (df_eval_backprop_deriv σ df1 grad_env2 - (msum (fun i j => (d i j * grad i j)%R))) - ; intros. - rewrite H6 in H0; simpl in H0. - + specialize (IHdf2 (fun i j => (grad i j * d0)%R) d3 d4). - replace - (c * - (@msum floatish_R _ _ - (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d i j * grad i j)))%R with - - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d i j * (c * grad i j))%R)) in IHdf1. - * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1). - case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros. - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2). - case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros. - rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *. - case_eq (df_eval_backprop_deriv σ df2 d3 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (c * grad i j * d0)%R)) - ; [|tauto]; intros. - case_eq (df_eval_backprop_deriv σ df2 d4 (fun i j => (grad i j * d0)%R)) - ; [|tauto]; intros; simpl; f_equal. - rewrite (split_subvar d3 d7 d2 d5) by trivial. - rewrite (split_subvar d4 d8 d1 d6) by trivial. - replace (fun i j => (c * (grad i j * d0))%R) with - (fun i j => (c * grad i j * d0)%R) in IHdf2. - -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *. - assert (Some (subvar (s, DTfloat) d3 d2) = - Some (c * subvar (s, DTfloat) d4 d1)%R) by - (apply IHdf1; trivial; discriminate). - assert (Some (subvar (s, DTfloat) d7 d5) = - Some (c * subvar (s, DTfloat) d8 d6)%R) by - (apply IHdf2; trivial; discriminate). - inversion H13; inversion H14. - rewrite H16; rewrite H17; lra. - -- apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - * unfold msum. - rewrite vsum_mult; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite vsum_mult. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + now rewrite H6 in H0. - + rewrite H2 in H; rewrite H1 in H. - now rewrite H5 in H. - - Case "VectorApply"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - simpl in *. - case_eq (df_eval σ df2); [ | tauto]. - intros. - rewrite H3 in H; rewrite H3 in H0; simpl in *. - match_option_in H0; [|tauto]. - specialize (IHdf1 v0 grad_env1 grad_env2). - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1. - match_option_in H; [|tauto]. - vectoro_assert_forall_in eqq i. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in eqq0 i. - apply vectoro_to_ovector_forall_some_f; trivial. - assert (v1 = (fun i => (c * v0 i)%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H4 x). - specialize (H5 x). - rewrite vmap_nth in H4; simpl in H4. - rewrite vmap_nth in H5; simpl in H5. - match_option_in H4. - match_option_in H5. - inversion H4; inversion H5; subst. - assert (Some d2 = Some d3). - rewrite <- eqq1. - rewrite <- eqq2; trivial. - inversion H6; subst; lra. - subst. - apply IHdf1; trivial; discriminate. - - Case "MatrixApply"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - simpl in *. - case_eq (df_eval σ df2); [ | tauto]. - intros. - rewrite H3 in H; rewrite H3 in H0; simpl in *. - match_option_in H0; [|tauto]. - match_option_in H; [|tauto]. - unfold matrixo_to_omatrix in eqq. - unfold matrixo_to_omatrix in eqq0. - vectoro_assert_forall_in eqq i. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in eqq0 i. - apply vectoro_to_ovector_forall_some_f; trivial. - assert (m1 = (fun i j => (c * m0 i j)%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H4 x); specialize (H5 x); simpl in H4; simpl in H5. - vectoro_assert_forall_in H4 j. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in H5 j. - apply vectoro_to_ovector_forall_some_f; trivial. - specialize (H6 x0); specialize (H7 x0). - unfold mmap in H6; unfold mmap in H7. - rewrite vmap_nth in H6; rewrite vmap_nth in H6; simpl in H6. - rewrite vmap_nth in H7; rewrite vmap_nth in H7; simpl in H7. - match_case_in H6; intros. - rewrite H8 in H6; simpl in H6. - match_case_in H7; intros. - rewrite H9 in H7; simpl in H7. - match_option_in H6. - match_option_in H7. - inversion H6; inversion H7. - unfold matrix_zip in H8. - unfold matrix_zip in H9. - rewrite vmap_nth in H8. - rewrite vmap_nth in H9. - unfold vector_zip in H8. - unfold vector_zip in H9. - inversion H8; subst r r0. - inversion H9; subst r1 r2. - assert (Some d2 = Some d3). - rewrite <- eqq1; rewrite <- eqq2; trivial. - inversion H10; subst; lra. - specialize (IHdf1 m0 grad_env1 grad_env2). - rewrite H1, H2 in IHdf1; simpl in IHdf1. - subst m1. - apply IHdf1; trivial; discriminate. - - Case "VLossfun"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - simpl in *. - case_eq (df_eval σ df2); [ | tauto]. - intros. - rewrite H3 in H; rewrite H3 in H0; simpl in *. - match_option_in H0; [|tauto]. - match_option_in H; [|tauto]. - vectoro_assert_forall_in eqq i. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in eqq0 i. - apply vectoro_to_ovector_forall_some_f; trivial. - assert (v0 = (fun i => (c * v i)%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H4 x); specialize (H5 x). - rewrite vmap_nth in H4; simpl in H4. - rewrite vmap_nth in H5; simpl in H5. - match_option_in H4. - match_option_in H5. - assert (Some d2 = Some d3). - rewrite <- eqq1; rewrite <- eqq2; trivial. - inversion H6; subst. - inversion H4; inversion H5; lra. - subst. - specialize (IHdf1 v grad_env1 grad_env2). - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1. - apply IHdf1; trivial; discriminate. - - Case "MLossfun"%string. - case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros. - case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros. - simpl in *. - case_eq (df_eval σ df2); [ | tauto]. - intros. - rewrite H3 in H; rewrite H3 in H0; simpl in *. - match_option_in H0; [|tauto]. - match_option_in H; [|tauto]. - unfold matrixo_to_omatrix in eqq. - unfold matrixo_to_omatrix in eqq0. - vectoro_assert_forall_in eqq i. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in eqq0 i. - apply vectoro_to_ovector_forall_some_f; trivial. - assert (m1 = (fun i j => (c * m0 i j)%R)). - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H4 x); specialize (H5 x); simpl in H4; simpl in H5. - vectoro_assert_forall_in H4 j. - apply vectoro_to_ovector_forall_some_f; trivial. - vectoro_assert_forall_in H5 j. - apply vectoro_to_ovector_forall_some_f; trivial. - specialize (H6 x0); specialize (H7 x0). - unfold mmap in H6; unfold mmap in H7. - rewrite vmap_nth in H6; rewrite vmap_nth in H6; simpl in H6. - rewrite vmap_nth in H7; rewrite vmap_nth in H7; simpl in H7. - match_destr_in H6. - match_option_in H6. - match_option_in H7. - assert (Some d2 = Some d3). - rewrite <- eqq1; rewrite <- eqq2; trivial. - inversion H8; subst. - inversion H6; inversion H7. - lra. - rewrite H6. - specialize (IHdf1 m0 grad_env1 grad_env2). - rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1. - subst. - apply IHdf1; trivial; discriminate. - Qed. - - Ltac simpl_closed_backprop := - match goal with - | [|- context [ - match df_eval_backprop_deriv ?σ ?df1 ?grad_env1 ?grad with - | Some _ => _ - | None => _ - end]] => case_eq (df_eval_backprop_deriv σ df1 grad_env1 grad) - ; [let env := fresh "env" in let eqq := fresh "eqq" in intros env eqq | - let eqq := fresh "eqq" in - intros eqq; - eelim backprop_deriv_fully_closed_not_none; [clear eqq | eapply eqq]; trivial - ] - end. - - Ltac simpler2 := - trivial; - repeat - match goal with - | [ |- Some _ <> None ] => congruence - | [ |- None <> Some _ ] => congruence - - | [H:vartlookup ?grad_env ?a <> None - |- context [vartlookup ?grad_env ?a]] => - case_eq (vartlookup grad_env a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d - | intros ?; eelim H; solve[auto]] - | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2 - |- context [vartlookup ?grad_env2 ?a]] => - case_eq (vartlookup grad_env2 a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d - | let eqq := fresh "eqq" in - intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto] - ] - - | [H:vartlookup ?grad_env ?a <> None - |- context [vartlookup ?grad_env ?a]] => - case_eq (vartlookup grad_env a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d - | intros ?; eelim H; solve[auto || congruence]] - | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2 - |- context [vartlookup ?grad_env2 ?a]] => - case_eq (vartlookup grad_env2 a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d - | let eqq := fresh "eqq" in - intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto || congruence] - ] - | [H:vartlookup ?grad_env ?a <> None, - H2:context [match vartlookup ?grad_env ?a with | _ => _ end] |- _] => - case_eq (vartlookup grad_env a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d; rewrite eqq in H2 - | intros ?; eelim H; solve[auto || congruence]] - | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2, - H2: context [match vartlookup ?grad_env2 ?a with _ => _ end] |- _] => - case_eq (vartlookup grad_env2 a); [ - let d := fresh "val" in - let eqq := fresh "eqq" in - intros d eqq; simpl in d; rewrite eqq in H2 - | let eqq := fresh "eqq" in - intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto || congruence]] - - end. - - Lemma backprop_indep_env {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s:SubVar) - (grad_env1 grad_env2:df_env) (grad : definition_function_types_interp T) : - let v := (s, DTfloat) in - vartlookup grad_env1 v <> None -> - vartlookup grad_env2 v <> None -> - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - df_eval_backprop_delta σ df v grad_env1 grad = - df_eval_backprop_delta σ df v grad_env2 grad. - Proof. - revert grad_env1 grad_env2. - unfold df_eval_backprop_delta. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case - ; simpl; intros grad_env1 grad_env2 neq1 neq2; intros. - - Case "Number"%string. - match_option; [|tauto]. - match_option; [|tauto]. - unfold snd in *. - f_equal. - unfold subvar; simpl. - rewrite eqq, eqq0. - lra. - - Case "Constant"%string. - match_option; [|tauto]. - match_option; [|tauto]. - unfold snd in *. - f_equal. - unfold subvar; simpl. - rewrite eqq, eqq0. - lra. - - Case "DVector"%string. - rewrite vforall_forall in H0. - unfold two_vector_env_iter_alt. - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift; simpl. - match_option. - + match_option. - f_equal. - apply (list_env_iter_subvar_env2 - s d d0 {n' : nat | n' < n} - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad i)) - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad i)) - (bounded_seq0 n) grad_env1 grad_env2 d1 d2); trivial. - * intros. - specialize (H i (grad i) env1 env2). - cut_to H; try congruence; eauto 3. - rewrite H1, H2, H3, H4 in H. - unfold lift in H. - now inversion H. - * intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H1); trivial. - * intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H1); trivial. - * assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad i)) (Some grad_env2) - (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - + assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad i)) (Some grad_env1) - (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - - Case "DMatrix"%string. - rewrite vforall_forall in H0. - unfold two_matrix_env_iter_alt. - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift; simpl. - match_option. - + match_option. - * f_equal. - apply (list_env_iter_subvar_env2 - s d d0 {n' : nat | n' < n} - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (Some env) (bounded_seq0 m)) - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (Some env) (bounded_seq0 m)) - (bounded_seq0 n) grad_env1 grad_env2 d1 d2); trivial. - -- intros. - apply (list_env_iter_subvar_env2 - s v1 v2 {m' : nat | m' < m} - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (bounded_seq0 m) env1 env2 fenv1 fenv2); trivial. - ++ intros. - specialize (H i i0 (grad i i0) env0 env3). - cut_to H. - rewrite H5, H6, H7, H8 in H. - unfold lift in H. - now inversion H. - congruence. - congruence. - specialize (H0 i). - rewrite vforall_forall in H0. - apply (H0 i0). - ++ intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5); trivial. - ++ intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5); trivial. - -- intros. - apply (vartlookup_list_env_iter - s - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (bounded_seq0 m) env fenv); trivial. - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3); trivial. - -- intros. - apply (vartlookup_list_env_iter - s - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (bounded_seq0 m) env fenv); trivial. - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H3); trivial. - * assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (Some env) (bounded_seq0 m)) (Some grad_env2) (bounded_seq0 n) - <> None). - apply list_env_iter_total_fun; intros. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - specialize (H0 a). - rewrite vforall_forall in H0. - apply (H0 a0). - tauto. - + assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad i j)) - (Some env) (bounded_seq0 m)) (Some grad_env1) (bounded_seq0 n) - <> None). - apply list_env_iter_total_fun; intros. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - specialize (H0 a). - rewrite vforall_forall in H0. - apply (H0 a0). - tauto. - - Case "Var"%string. - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - unfold lift, subvar; simpl. - destruct (v == (s,DTfloat)). - + invcs e. - unfold addvar; simpl. - rewrite eqq, H0. - rewrite lookup_update. - rewrite lookup_update. - f_equal; lra. - + assert (v<> (s, DTfloat)) by congruence. - case_eq (vartlookup grad_env1 v); intros. - * rewrite lookup_update_neq; trivial. - rewrite eqq. - case_eq (vartlookup grad_env2 v); intros. - -- rewrite lookup_update_neq; trivial. - rewrite H0; f_equal; lra. - -- rewrite H0; f_equal; lra. - * rewrite eqq. - case_eq (vartlookup grad_env2 v); intros. - -- rewrite lookup_update_neq; trivial. - rewrite H0; f_equal; lra. - -- rewrite H0; f_equal; lra. - - Case "Plus"%string. - destruct H. - unfold lift. - repeat simpl_closed_backprop. - simpler2. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1,eqq3,eqq4 in IHdf1. - specialize (IHdf2 grad env env1). - cut_to IHdf2; simpler2. - unfold lift in IHdf1; inversion IHdf1. - rewrite eqq2, eqq0 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar env env0 val val1); trivial. - rewrite (split_subvar env1 env2 val0 val2); trivial. - rewrite H2, H3; lra. - - Case "Minus"%string. - destruct H. - unfold lift. - repeat simpl_closed_backprop. - simpler2. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1,eqq3,eqq4 in IHdf1. - specialize (IHdf2 (-grad)%R env env1). - cut_to IHdf2; simpler2. - unfold lift in IHdf1; invcs IHdf1. - rewrite eqq0,eqq2 in IHdf2; unfold lift in IHdf2; invcs IHdf2. - rewrite (split_subvar env env0 val val1); trivial. - rewrite (split_subvar env1 env2 val0 val2); trivial. - rewrite H2, H3; lra. - - Case "Times"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - unfold lift. - repeat simpl_closed_backprop. - simpler2. - f_equal. - specialize (IHdf1 (d1 * grad)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq6,eqq2,eqq4 in IHdf1. - specialize (IHdf2 (d0 * grad)%R env env1). - cut_to IHdf2; simpler2. - unfold lift in IHdf1; invcs IHdf1. - rewrite eqq5, eqq3 in IHdf2; unfold lift in IHdf2; invcs IHdf2. - rewrite (split_subvar env env0 d val0); trivial. - rewrite (split_subvar env1 env2 val val1); trivial. - rewrite H4, H5; lra. - - Case "Divide"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - unfold lift. - repeat simpl_closed_backprop. - simpler2. - f_equal. - specialize (IHdf1 (grad/d1)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq6,eqq2,eqq4 in IHdf1. - specialize (IHdf2 (-d0/(d1*d1) * grad)%R env env1). - cut_to IHdf2; simpler2. - unfold lift in IHdf1; inversion IHdf1. - rewrite eqq5, eqq3 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar env env0 d val0); trivial. - rewrite (split_subvar env1 env2 val val1); trivial. - rewrite H4, H5; lra. - - Case "Square"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift. - repeat simpl_closed_backprop. - f_equal. - specialize (IHdf (2 * d0 * grad)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1,eqq2,eqq3 in IHdf. - now unfold lift in IHdf; inversion IHdf. - - Case "Exp"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift. - repeat simpl_closed_backprop. - f_equal. - specialize (IHdf (grad * exp d0)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1,eqq2, eqq3 in IHdf. - now unfold lift in IHdf; inversion IHdf. - - Case "Log"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env1 (grad/d0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - unfold lift. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env2 (grad/d0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf (grad / d0)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf. - now rewrite eqq2, eqq3 in IHdf; unfold lift in IHdf; inversion IHdf. - - Case "Abs"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env1 (grad * sign d0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - unfold lift. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env2 (grad * sign d0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf (grad * sign d0)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf. - now rewrite eqq2, eqq3 in IHdf; unfold lift in IHdf; inversion IHdf. - - Case "Sign"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env1 (0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - unfold lift. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env2 (0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf (0)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq0 in IHdf. - now rewrite eqq1, eqq2 in IHdf; unfold lift in IHdf; inversion IHdf. - - Case "PSign"%string. - match_option; [|tauto]. - assert (df_eval σ df <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env1 (0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - unfold lift. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df grad_env2 (0)%R <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf (0)%R grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq0 in IHdf. - now rewrite eqq1, eqq2 in IHdf; unfold lift in IHdf; inversion IHdf. - - Case "Max"%string. - destruct H. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - specialize (IHdf2 grad grad_env1 grad_env2 neq1 neq2 H0). - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - match_option; [|tauto]. - rewrite eqq, eqq2 in IHdf1. - rewrite eqq, eqq2 in IHdf2. - destruct (Rle_dec d0 d1); trivial. - - Case "VectorDot"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 (vmap (fun rv : R => (rv * grad)%R) d1) - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H3 in IHdf1. - assert (df_eval_backprop_deriv σ df1 grad_env1 - (vmap (fun rv : R => (rv * grad)%R) d1) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (vmap (fun rv : R => (rv * grad)%R) d1) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1). - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2). - specialize (IHdf2 (vmap (fun lv : R => (lv * grad)%R) d0) - d3 d4 H6 H7 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv σ df2 d3 - (vmap (fun lv : R => (lv * grad)%R) d0) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (vmap (fun lv : R => (lv * grad)%R) d0) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - unfold lift in IHdf1. - rewrite eqq2, eqq3 in IHdf1; inversion IHdf1. - rewrite eqq6, eqq7 in IHdf2; inversion IHdf2. - rewrite (split_subvar d3 d7 d d5); trivial. - rewrite (split_subvar d4 d8 d2 d6); trivial. - now rewrite H11, H12. - - Case "VectorSum"%string. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (ConstVector n grad) grad_env1 grad_env2 neq1 neq2 H). - now rewrite eqq, eqq0 in IHdf. - - Case "MatrixSum"%string. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (ConstMatrix m n grad) grad_env1 grad_env2 neq1 neq2 H). - now rewrite eqq, eqq0 in IHdf. - - Case "VectorElem"%string. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i) - then grad else 0%R) - grad_env1 grad_env2 neq1 neq2 H). - now rewrite eqq, eqq0 in IHdf. - - Case "MatrixElem"%string. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) then - if equiv_dec (` k2) (` j) then grad else 0%R else 0%R) - grad_env1 grad_env2 neq1 neq2 H). - now rewrite eqq, eqq0 in IHdf. - - Case "MatrixVectorMult"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 (fun (i : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (grad i * d1 j)%R) - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H3 in IHdf1. - assert (df_eval_backprop_deriv σ df1 grad_env1 - (fun (i : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (grad i * d1 j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (fun (i : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (grad i * d1 j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1). - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2). - specialize (IHdf2 - (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d0 j i) grad) - d3 d4 H6 H7 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv - σ df2 d3 - (matrix_vector_mult (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d0 j i) - grad) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (matrix_vector_mult (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d0 j i) - grad) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - unfold lift in IHdf1. - unfold lift in IHdf2. - rewrite eqq2, eqq3 in IHdf1; inversion IHdf1. - rewrite eqq6, eqq7 in IHdf2; inversion IHdf2. - rewrite (split_subvar d3 d7 d d5); trivial. - rewrite (split_subvar d4 d8 d2 d6); trivial. - now rewrite H11, H12. - - Case "MatrixVectorAdd"%string. - destruct H. - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 grad - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H1 in IHdf1. - assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); [intros | tauto]. - rewrite eqq0, H4 in IHdf1. - unfold lift in IHdf1. - inversion IHdf1. - assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d1) - (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d2) - (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - unfold lift; f_equal. - assert (vartlookup d1 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1). - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2). - assert (vartlookup d3 (s, DTfloat) <> None). - apply - (vartlookup_list_env_iter - s - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) - (bounded_seq0 n) d1); trivial. - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H10 (s, DTfloat) H11). - assert (vartlookup d4 (s, DTfloat) <> None). - apply - (vartlookup_list_env_iter - s - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) - (bounded_seq0 n) d2); trivial. - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H11 (s, DTfloat) H12). - unfold subvar in IHdf1; simpl in IHdf1. - match_option_in IHdf1; [|tauto]. - match_option_in IHdf1; [|tauto]. - rewrite (split_subvar d1 d3 d d5); trivial. - rewrite (split_subvar d2 d4 d0 d6); trivial. - rewrite H6. - apply Rplus_eq_compat_r . - apply (list_env_iter_subvar_env2 - s d5 d6 {m' : nat | m' < n} - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad i)) - (bounded_seq0 n) d1 d2 d3 d4); trivial. - intros. - specialize (IHdf2 (transpose grad i) env1 env2). - rewrite H12, H13 in IHdf2. - cut_to IHdf2; trivial. - rewrite H14, H15 in IHdf2. - unfold lift in IHdf2. - now inversion IHdf2. - discriminate. - discriminate. - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H12 (s, DTfloat) H13). - intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H12 (s, DTfloat) H13). - - Case "MatrixMult"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 - (matrix_mult grad (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < p}) => d1 j i)) - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H3 in IHdf1. - assert (df_eval_backprop_deriv σ df1 grad_env1 - (matrix_mult grad (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < p}) => d1 j i)) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (matrix_mult grad (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < p}) => d1 j i)) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - rewrite eqq2, eqq3 in IHdf1. - unfold lift in IHdf1; inversion IHdf1. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1). - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2). - specialize (IHdf2 - (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d0 j i) grad) - d3 d4 H6 H8 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv - σ df2 d3 - (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d0 j i) grad) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv - σ df2 d4 - (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d0 j i) grad) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - unfold lift in IHdf1. - unfold lift in IHdf2. - rewrite eqq6, eqq7 in IHdf2; inversion IHdf2. - rewrite (split_subvar d3 d7 d d5); trivial. - rewrite (split_subvar d4 d8 d2 d6); trivial. - now rewrite H7, H12. - - Case "VectorPlus"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv σ df2 d0 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf1. - specialize (IHdf2 grad d0 d2). - assert (vartlookup d0 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1). - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2). - specialize (IHdf2 H5 H6 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1. - rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar d0 d3 d d5); trivial. - rewrite (split_subvar d2 d4 d1 d6); trivial. - lra. - - Case "VectorMinus"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv - σ df2 d0 - (fun i : {n' : nat | n' < n} => (- grad i)%R) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d2 - (fun i : {n' : nat | n' < n} => (- grad i)%R) <> None) - by (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf1. - specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad i)%R) d0 d2). - assert (vartlookup d0 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1). - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2). - specialize (IHdf2 H5 H6 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1. - rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar d0 d3 d d5); trivial. - rewrite (split_subvar d2 d4 d1 d6); trivial. - lra. - - Case "MatrixPlus"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv σ df2 d0 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf1. - specialize (IHdf2 grad d0 d2). - assert (vartlookup d0 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1). - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2). - specialize (IHdf2 H5 H6 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1. - rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar d0 d3 d d5); trivial. - rewrite (split_subvar d2 d4 d1 d6); trivial. - lra. - - Case "MatrixMinus"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv - σ df2 d0 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, eqq1 in IHdf1. - specialize (IHdf2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R) - d0 d2). - assert (vartlookup d0 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1). - assert (vartlookup d2 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2). - specialize (IHdf2 H5 H6 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1. - rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2. - rewrite (split_subvar d0 d3 d d5); trivial. - rewrite (split_subvar d2 d4 d1 d6); trivial. - lra. - - Case "VectorScalMult"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R)) - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H3 in IHdf1. - assert (df_eval_backprop_deriv - σ df1 grad_env1 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R)) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R)) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1). - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2). - specialize (IHdf2 (fun j : {n' : nat | n' < n} => (d0 * grad j)%R) - d3 d4 H6 H7 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv σ df2 d3 - (fun j : {n' : nat | n' < n} => (d0 * grad j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (fun j : {n' : nat | n' < n} => (d0 * grad j)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - unfold lift in IHdf1. - rewrite eqq2, eqq3 in IHdf1; inversion IHdf1. - rewrite eqq6, eqq7 in IHdf2; inversion IHdf2. - rewrite (split_subvar d3 d7 d d5); trivial. - rewrite (split_subvar d4 d8 d2 d6); trivial. - now rewrite H11, H12. - - Case "MatrixScalMult"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df1 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - specialize (IHdf1 - (msum - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R)) - grad_env1 grad_env2 neq1 neq2 H). - rewrite eqq, H3 in IHdf1. - assert (df_eval_backprop_deriv - σ df1 grad_env1 - (msum - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R)) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (msum - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R)) <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (vartlookup d3 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1). - assert (vartlookup d4 (s, DTfloat) <> None) by - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2). - specialize (IHdf2 (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad i j * d0)%R) - d3 d4 H6 H7 H0). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift. - assert (df_eval_backprop_deriv σ df2 d3 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad i j * d0)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad i j * d0)%R) - <> None) by - (apply backprop_deriv_fully_closed_not_none; trivial). - match_option; [|tauto]. - f_equal. - unfold lift in IHdf1. - rewrite eqq2, eqq3 in IHdf1; inversion IHdf1. - rewrite eqq6, eqq7 in IHdf2; inversion IHdf2. - rewrite (split_subvar d3 d7 d d5); trivial. - rewrite (split_subvar d4 d8 d2 d6); trivial. - now rewrite H11, H12. - - Case "VectorApply"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - match_option. - specialize (IHdf1 v0 grad_env1 grad_env2 neq1 neq2 H0). - now rewrite eqq, H2 in IHdf1. - - Case "MatrixApply"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - match_option. - specialize (IHdf1 m0 grad_env1 grad_env2 neq1 neq2 H0). - now rewrite eqq, H2 in IHdf1. - - Case "VLossfun"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - match_option. - specialize (IHdf1 v grad_env1 grad_env2 neq1 neq2 H0). - now rewrite eqq, H2 in IHdf1. - - Case "MLossfun"%string. - destruct H. - match_option; [|tauto]. - assert (df_eval σ df2 <> None) by - (apply eval_fully_closed_not_none;trivial). - match_option; [|tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto]. - match_option. - specialize (IHdf1 m0 grad_env1 grad_env2 neq1 neq2 H0). - now rewrite eqq, H2 in IHdf1. - Qed. - - Lemma backprop_exchange_order {T} (σ:df_env) (df1 df2 :DefinedFunction UnitAnn T) (s: SubVar) - (env:df_env) (grad1 grad2 : definition_function_types_interp T) : - let v := (s, DTfloat) in - vartlookup env v <> None -> - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df1 vl -> - fully_closed_over df2 vl -> - match - df_eval_backprop_deriv σ df1 env grad1, df_eval_backprop_deriv σ df2 env grad2, vartlookup env v with - | Some env1, Some env2, Some oval => - lift (fun e => subvar v e oval) - (df_eval_backprop_deriv σ df1 env2 grad1) = - lift (fun e => subvar v e oval) - (df_eval_backprop_deriv σ df2 env1 grad2) - | _, _, _ => True - end. - Proof. - intros. - do 3 match_option. - unfold lift. - do 2 match_option. - - assert (vartlookup d0 v <> None); simpler2. - assert (vartlookup d v <> None); simpler2. - case_eq (vartlookup d0 v); [intros|tauto]. - case_eq (vartlookup d v); [intros|tauto]. - f_equal. subst v. - rewrite (split_subvar d0 d2 d1 d4); trivial. - rewrite (split_subvar d d3 d1 d5); trivial. - generalize (backprop_indep_env σ df1 s env d0 grad1); intros. - generalize (backprop_indep_env σ df2 s env d grad2); intros. - simpl in H6; simpl in H7. - cut_to H6; trivial; try congruence. - cut_to H7; trivial; try congruence. - unfold df_eval_backprop_delta in *. - rewrite eqq1, H4 in H6. - rewrite eqq1, H5 in H7. - unfold lift in H6; simpl in H6. - unfold lift in H7; simpl in H7. - rewrite eqq, eqq2 in H6. - rewrite eqq0, eqq3 in H7. - inversion H6; inversion H7. - rewrite H9, H10; lra. - - assert (df_eval_backprop_deriv σ df2 d grad2 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - - assert (df_eval_backprop_deriv σ df1 d0 grad1 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - Qed. - - Lemma list_env_iter_id {A} (env : df_env) (l : list A) : - list_env_iter (fun (_ : A) (env : df_env) => Some env) - (Some env) l = Some env. - Proof. - now induction l. - Qed. - - Lemma backprop_grad_sum_list_env_iter {m} (σ:df_env) - (vecdf:Vector (DefinedFunction UnitAnn DTfloat) m) (s: SubVar) - (grad_env1 grad_env2 grad_env3:df_env) - (grad1 grad2 : (Vector float m)) - (val1 val2 val3 : float) - (l : list {m' | m' < m} ) - : - let v := (s, DTfloat) in - (forall (i:{m' | m' < m}), - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over (vecdf i) vl) -> - (forall (i:{m' | m' < m}) (env1 env2 env3 : df_env), - vartlookup env1 v <> None -> - vartlookup env2 v <> None -> - vartlookup env3 v <> None -> - - df_eval_backprop_delta σ (vecdf i) v env3 - (grad1 i + grad2 i)%R = - lift2 dfti_gen_plus - (df_eval_backprop_delta σ (vecdf i) v env1 (grad1 i)) - (df_eval_backprop_delta σ (vecdf i) v env2 (grad2 i))) -> - - vartlookup grad_env1 v = Some val1 -> - vartlookup grad_env2 v = Some val2 -> - vartlookup grad_env3 v = Some val3 -> - - lift (fun e : df_env => subvar v e val3) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad1 j + grad2 j)%R) - (Some grad_env3) l) = - lift2 dfti_gen_plus - (lift (fun e : df_env => subvar v e val1) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some grad_env1) l)) - (lift (fun e : df_env => subvar v e val2) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some grad_env2) l)). - Proof. - intros. - revert val1 val2 val3 grad_env1 grad_env2 grad_env3 H1 H2 H3. - induction l. - - intros. - simpl; f_equal. - unfold subvar; simpl. - rewrite H1,H2,H3. - lra. - - intros. - simpl. - unfold df_eval_backprop_delta in H0. - unfold lift, lift2. - assert (df_eval_backprop_deriv σ (vecdf a) grad_env3 (grad1 a + grad2 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env3 (grad1 a + grad2 a)%R) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ (vecdf a) grad_env1 (grad1 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env1 (grad1 a)%R) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ (vecdf a) grad_env2 (grad2 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env2 (grad2 a)%R) - ; [intros | tauto]. - - assert (vartlookup d v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 v); congruence. - assert (vartlookup d0 v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 v); congruence. - assert (vartlookup d1 v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 v); congruence. - - case_eq (vartlookup d v); [intros v3 eq3 |tauto]. - case_eq (vartlookup d0 v); [intros v1 eq1 |tauto]. - case_eq (vartlookup d1 v); [intros v2 eq2 |tauto]. - - specialize (IHl v1 v2 v3 d0 d1 d eq1 eq2 eq3). - unfold lift, lift2. - match_option. - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some d0) l). - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some d1) l). - + intros. - rewrite eqq, H13, H14 in IHl. - unfold lift, lift2 in IHl; simpl in IHl. - f_equal. - subst v. - rewrite (split_subvar d d2 val3 v3); trivial. - rewrite (split_subvar d0 d4 val1 v1); trivial. - rewrite (split_subvar d1 d3 val2 v2); trivial. - inversion IHl. - rewrite H16. - specialize (H0 a grad_env1 grad_env2 grad_env3). - cut_to H0; try congruence. - rewrite H1,H2,H3 in H0. - rewrite H5,H7,H9 in H0. - unfold lift, lift2 in H0. - inversion H0. - rewrite H17; lra. - + intros. - assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some d1) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - + intros. - assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some d0) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - + assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (vecdf j) env (grad1 j + grad2 j)%R) (Some d) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - Qed. - - Lemma backprop_mat_grad_sum_list_env_iter {m n} (σ:df_env) - (df : DefinedFunction UnitAnn (DTVector n)) (s: SubVar) - (grad_env1 grad_env2 grad_env3:df_env) - (grad1 grad2 : (Matrix float m n)) - (val1 val2 val3 : float) - (l : list {m' | m' < m} ) - : - let v := (s, DTfloat) in - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - (forall (i:{m' | m' < m}) (env1 env2 env3 : df_env), - vartlookup env1 v <> None -> - vartlookup env2 v <> None -> - vartlookup env3 v <> None -> - - df_eval_backprop_delta σ df v env3 (fun j => (grad1 i j + grad2 i j)%R) = - lift2 dfti_gen_plus - (df_eval_backprop_delta σ df v env1 (grad1 i)) - (df_eval_backprop_delta σ df v env2 (grad2 i))) -> - - vartlookup grad_env1 v = Some val1 -> - vartlookup grad_env2 v = Some val2 -> - vartlookup grad_env3 v = Some val3 -> - - lift (fun e : df_env => subvar v e val3) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (fun k => (grad1 j k + grad2 j k)%R) ) - (Some grad_env3) l) = - lift2 dfti_gen_plus - (lift (fun e : df_env => subvar v e val1) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad1 j)) (Some grad_env1) l)) - (lift (fun e : df_env => subvar v e val2) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad2 j)) (Some grad_env2) l)). - Proof. - intros. - revert val1 val2 val3 grad_env1 grad_env2 grad_env3 H1 H2 H3. - induction l. - - intros. - simpl; f_equal. - unfold subvar; simpl. - rewrite H1,H2,H3. - lra. - - intros. - unfold df_eval_backprop_delta in *. - simpl. - unfold lift, lift2. - assert (df_eval_backprop_deriv σ df grad_env3 - (fun k : {n' : nat | n' < n} => (grad1 a k + grad2 a k)%R) <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ df grad_env3 - (fun k : {n' : nat | n' < n} => (grad1 a k + grad2 a k)%R)) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ df grad_env1 (grad1 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ df grad_env1 (grad1 a)%R) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ df grad_env2 (grad2 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H. - case_eq (df_eval_backprop_deriv σ df grad_env2 (grad2 a)%R) - ; [intros | tauto]. - - assert (vartlookup d v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 v); congruence. - assert (vartlookup d0 v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 v); congruence. - assert (vartlookup d1 v <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 v); congruence. - - case_eq (vartlookup d v); [intros v3 eq3 |tauto]. - case_eq (vartlookup d0 v); [intros v1 eq1 |tauto]. - case_eq (vartlookup d1 v); [intros v2 eq2 |tauto]. - - specialize (IHl v1 v2 v3 d0 d1 d eq1 eq2 eq3). - match_option. - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad1 j)) (Some d0) l). - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad2 j)) (Some d1) l). - + intros. - rewrite eqq, H13, H14 in IHl. - unfold lift, lift2 in IHl; simpl in IHl. - f_equal. - subst v. - rewrite (split_subvar d d2 val3 v3); trivial. - rewrite (split_subvar d0 d4 val1 v1); trivial. - rewrite (split_subvar d1 d3 val2 v2); trivial. - inversion IHl. - rewrite H16. - specialize (H0 a grad_env1 grad_env2 grad_env3). - cut_to H0; try congruence. - rewrite H1,H2,H3 in H0. - rewrite H5,H7,H9 in H0. - unfold lift, lift2 in H0. - inversion H0. - rewrite H17; lra. - + intros. - assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad2 j)) (Some d1) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - + intros. - assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env (grad1 j)) (Some d0) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - + assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ df env - (fun k : {n' : nat | n' < n} => (grad1 j k + grad2 j k)%R)) - (Some d) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply H. - tauto. - Qed. - - Lemma matrix_zip_m_n {T} {m n} {i j} {m1 m2 : Matrix T m n} : - matrix_zip m1 m2 i j = (m1 i j, m2 i j). - Proof. - unfold matrix_zip. - rewrite vmap_nth; simpl. - now unfold vector_zip. - Qed. - - Lemma backprop_grad_sum {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s: SubVar) - (grad_env1 grad_env2 grad_env3:df_env) - (grad1 grad2 : definition_function_types_interp T) : - let v := (s, DTfloat) in - vartlookup grad_env1 v <> None -> - vartlookup grad_env2 v <> None -> - vartlookup grad_env3 v <> None -> - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - df_eval_backprop_delta σ df (s,DTfloat) grad_env3 (dfti_gen_plus grad1 grad2) = - lift2 dfti_gen_plus - (df_eval_backprop_delta σ df (s,DTfloat) grad_env1 grad1) - (df_eval_backprop_delta σ df (s,DTfloat) grad_env2 grad2). - Proof. - unfold df_eval_backprop_delta. - revert grad_env1 grad_env2 grad_env3. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case -(* - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case -*) - ; simpl; intros. - - Case "Number"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1. - f_equal; lra. - - Case "Constant"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1. - f_equal; lra. - - Case "DVector"%string. - unfold two_vector_env_iter_alt in *. - rewrite vforall_forall in H3. - revert grad_env1 grad_env2 grad_env3 H0 H1 H2. - induction (bounded_seq0 n). - + intros. - simpl. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1. - f_equal; lra. - + intros. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift; simpl in *. - assert (df_eval_backprop_deriv σ (x a) grad_env3 (grad1 a + grad2 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - assert (df_eval_backprop_deriv σ (x a) grad_env1 (grad1 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - assert (df_eval_backprop_deriv σ (x a) grad_env2 (grad2 a)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ (x a) grad_env3 (grad1 a + grad2 a)%R) - ; [intros|tauto]. - case_eq (df_eval_backprop_deriv σ (x a) grad_env1 (grad1 a)%R) - ; [intros|tauto]. - case_eq (df_eval_backprop_deriv σ (x a) grad_env2 (grad2 a)%R) - ; [intros|tauto]. - assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad1 i + grad2 i)%R) - (Some d2) l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply (H3 a0). - match_option; [|tauto]. - assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad1 i)) (Some d3) - l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply (H3 a0). - match_option; [|tauto]. - assert (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (grad2 i)) (Some d4) - l <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - apply (H3 a0). - match_option; [|tauto]. - unfold lift2; simpl. - - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 (s, DTfloat) H2). - assert (vartlookup d3 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H8 (s, DTfloat) H0). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 (s, DTfloat) H1). - - case_eq (vartlookup d2 (s, DTfloat)); [intros|tauto]. - case_eq (vartlookup d3 (s, DTfloat)); [intros|tauto]. - case_eq (vartlookup d4 (s, DTfloat)); [intros|tauto]. - - rewrite (split_subvar d2 d5 d d8); trivial. - rewrite (split_subvar d3 d6 d0 d9); trivial. - rewrite (split_subvar d4 d7 d1 d10); trivial. - - f_equal. - - specialize (IHl d3 d4 d2 H14 H15 H13). - rewrite H16, H17, H18 in IHl. - rewrite eqq2, eqq3, eqq4 in IHl. - unfold lift, lift2 in IHl. - inversion IHl. - rewrite H20. - - specialize (H a (grad1 a) (grad2 a) grad_env1 grad_env2 grad_env3 H0 H1 H2). - specialize (H (H3 a)). - - rewrite eqq,eqq0,eqq1 in H. - rewrite H7, H8, H9 in H. - unfold lift, lift2 in H. - inversion H. - rewrite H21. - lra. - - Case "DMatrix"%string. - unfold two_matrix_env_iter_alt in *. - rewrite vforall_forall in H3. - revert grad_env1 grad_env2 grad_env3 H0 H1 H2. - induction (bounded_seq0 n); induction (bounded_seq0 m). - + intros. - simpl. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1. - f_equal; lra. - + intros. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - simpl. - unfold lift, lift2; simpl; f_equal. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1; simpl; lra. - + intros. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - rewrite list_env_iter_id. - rewrite list_env_iter_id. - rewrite list_env_iter_id. - unfold subvar; simpl. - rewrite eqq, eqq0, eqq1; simpl; f_equal; lra. - + intros. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - simpl. - assert (df_eval_backprop_deriv σ (x a a0) grad_env3 (grad1 a a0 + grad2 a a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - assert (df_eval_backprop_deriv σ (x a a0) grad_env1 (grad1 a a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - assert (df_eval_backprop_deriv σ (x a a0) grad_env2 (grad2 a a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - case_eq (df_eval_backprop_deriv σ (x a a0) grad_env3 (grad1 a a0 + grad2 a a0)%R) - ; [intros|tauto]. - case_eq (df_eval_backprop_deriv σ (x a a0) grad_env1 (grad1 a a0)%R) - ; [intros|tauto]. - case_eq (df_eval_backprop_deriv σ (x a a0) grad_env2 (grad2 a a0)%R) - ; [intros|tauto]. - assert (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R) - (Some d2) l0 <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R) - (Some d2) l0 ); [intros | tauto]. - assert (list_env_iter - (fun (i : {n' : nat | n' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a i) env (grad1 a i)) (Some d3) - l0 <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad1 a j)%R) - (Some d3) l0 ); [intros | tauto]. - assert (list_env_iter - (fun (i : {n' : nat | n' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a i) env (grad2 a i)) (Some d4) - l0 <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a); rewrite vforall_forall in H3; apply H3. - case_eq (list_env_iter - (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad2 a j)%R) - (Some d4) l0 ); [intros | tauto]. - assert - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad1 i j + grad2 i j)%R) - (df_eval_backprop_deriv σ (x i a0) env (grad1 i a0 + grad2 i a0)%R) l0) - (Some d5) l <> None). - apply list_env_iter_total_fun; intros. - assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0 + grad2 a1 a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0 + grad2 a1 a0)%R); [intros|tauto]. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - unfold lift; simpl. - match_option; [|tauto]. - - assert - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad1 i j)%R) - (df_eval_backprop_deriv σ (x i a0) env (grad1 i a0)%R) l0) - (Some d6) l <> None). - apply list_env_iter_total_fun; intros. - assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0)%R); [intros|tauto]. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - match_option; [|tauto]. - - assert - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (grad2 i j)%R) - (df_eval_backprop_deriv σ (x i a0) env (grad2 i a0)%R) l0) - (Some d7) l <> None). - apply list_env_iter_total_fun; intros. - assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad2 a1 a0)%R <> None). - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad2 a1 a0)%R); [intros|tauto]. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (H3 a1); rewrite vforall_forall in H3; apply H3. - match_option; [|tauto]. - - unfold lift2; simpl; f_equal. - - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 (s, DTfloat) H2). - assert (vartlookup d3 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H8 (s, DTfloat) H0). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 (s, DTfloat) H1). - - case_eq (vartlookup d2 (s, DTfloat)); [intros|tauto]. - case_eq (vartlookup d3 (s, DTfloat)); [intros|tauto]. - case_eq (vartlookup d4 (s, DTfloat)); [intros|tauto]. - - assert (Hc := H). - assert (H3c := H3). - - specialize (H a a0 (grad1 a a0) (grad2 a a0) grad_env1 grad_env2 grad_env3 - H0 H1 H2). - - specialize (H3 a). - rewrite vforall_forall in H3. - specialize (H (H3 a0)). - rewrite eqq, eqq0, eqq1 in H; simpl in H. - rewrite H7, H8, H9 in H; unfold lift, lift2 in H; simpl in H. - - assert (vartlookup d5 (s, DTfloat) <> None). - apply (vartlookup_list_env_iter - s (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R) - l0 d2 d5); trivial; intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H25 (s, DTfloat) H26). - case_eq (vartlookup d5 (s, DTfloat)); [intros|tauto]. - assert (vartlookup d6 (s, DTfloat) <> None). - apply (vartlookup_list_env_iter - s (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad1 a j)%R) - l0 d3 d6); trivial; intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H27 (s, DTfloat) H28). - case_eq (vartlookup d6 (s, DTfloat)); [intros|tauto]. - assert (vartlookup d7 (s, DTfloat) <> None). - apply (vartlookup_list_env_iter - s (fun (j : {m' : nat | m' < m}) (env : df_env) => - df_eval_backprop_deriv σ (x a j) env (grad2 a j)%R) - l0 d4 d7); trivial; intros. - apply (df_eval_backprop_deriv_preserves_lookup_not_none H29 (s, DTfloat) H30). - case_eq (vartlookup d7 (s, DTfloat)); [intros|tauto]. - - specialize (IHl d6 d7 d5 H27 H29 H25). - rewrite H28, H30, H26 in IHl; simpl in IHl. - - rewrite eqq2,eqq3,eqq4 in IHl. - unfold lift, lift2 in IHl; simpl in IHl. - - rewrite (split_subvar d5 d8 d d14); trivial. - rewrite (split_subvar d6 d9 d0 d15); trivial. - rewrite (split_subvar d7 d10 d1 d16); trivial. - - rewrite (split_subvar d2 d5 d d11); trivial. - rewrite (split_subvar d3 d6 d0 d12); trivial. - rewrite (split_subvar d4 d7 d1 d13); trivial. - - inversion H. - inversion IHl. - rewrite H32, H33. - - generalize (backprop_grad_sum_list_env_iter - σ (x a) s d3 d4 d2 (grad1 a) (grad2 a) - d12 d13 d11 l0); intros. - specialize (H31 H3). - cut_to H31. - * rewrite H11,H13,H15 in H31. - unfold lift, lift2 in H31. - inversion H31. - rewrite H35; lra. - * intros. - unfold df_eval_backprop_delta. - specialize (Hc a i (grad1 a i) (grad2 a i) env1 env2 env3). - specialize (Hc H34 H35 H36). - specialize (Hc (H3 i)). - apply Hc. - * trivial. - * trivial. - * trivial. - - Case "Var"%string. - match_option; [|tauto]. - case_eq (vartlookup grad_env1 (s, DTfloat)); [intros| tauto]. - case_eq (vartlookup grad_env2 (s, DTfloat)); [intros| tauto]. - destruct (v == (s,DTfloat)). - + invcs e. - rewrite eqq, H3, H4; simpl. - f_equal. - rewrite subvar_addvar_scalar_eq; trivial. - rewrite subvar_addvar_scalar_eq; trivial. - rewrite subvar_addvar_scalar_eq; trivial. - + assert (v<> (s, DTfloat)) by congruence. - match_option. - * match_option. - -- match_option; unfold lift, lift2; simpl; f_equal. - ++ rewrite subvar_addvar_scalar_neq; trivial. - rewrite subvar_addvar_scalar_neq; trivial. - rewrite subvar_addvar_scalar_neq; trivial. - lra. - ++ rewrite subvar_addvar_scalar_neq; trivial. - rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl; rewrite H4; lra. - -- unfold lift, lift2; simpl; f_equal. - rewrite subvar_addvar_scalar_neq; trivial. - case_eq (vartlookup grad_env2 v); intros. - ++ rewrite subvar_addvar_scalar_neq; trivial. - f_equal; unfold subvar; simpl. - rewrite H3; lra. - ++ unfold subvar; simpl. - rewrite H3, H4. - f_equal; lra. - * match_option. - -- match_option; unfold lift, lift2; simpl; f_equal. - ++ rewrite subvar_addvar_scalar_neq; trivial. - rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl. - rewrite eqq; lra. - ++ rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl. - rewrite eqq, H4; lra. - -- match_option; unfold lift, lift2; simpl; f_equal. - ++ rewrite subvar_addvar_scalar_neq; trivial. - unfold subvar; simpl. - rewrite eqq, H3; lra. - ++ unfold subvar; simpl. - rewrite eqq, H3, H4; lra. - - Case "Plus"%string. - destruct H2. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 grad1 grad2 env1 env3 env). - simpl in IHdf1. - rewrite eqq5, eqq6, eqq7 in IHdf1. - rewrite eqq, eqq1, eqq3 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - rewrite eqq0, eqq2, eqq4 in IHdf2. - unfold lift, lift2 in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 val val2); trivial. - rewrite (split_subvar env1 env2 val0 val3); trivial. - rewrite (split_subvar env3 env4 val1 val4); trivial. - rewrite H5, H6; lra. - - Case "Minus"%string. - destruct H2. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 (-grad1)%R (-grad2)%R env1 env3 env). - simpl in IHdf1. - rewrite eqq5, eqq6, eqq7 in IHdf1. - rewrite eqq, eqq1, eqq3 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - replace (- grad1 + - grad2)%R with (- (grad1 + grad2))%R in IHdf2 by lra. - rewrite eqq0, eqq2, eqq4 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - rewrite (split_subvar env env0 val val2); trivial. - rewrite (split_subvar env1 env2 val0 val3); trivial. - rewrite (split_subvar env3 env4 val1 val4); trivial. - f_equal; rewrite H5, H6; lra. - - Case "Times"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 (d1 * grad1)%R (d1 * grad2)%R grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 (d0 * grad1)%R (d0 * grad2)%R env1 env3 env). - rewrite eqq8, eqq9, eqq in IHdf1. - simpl in IHdf1. - replace (d1 *grad1 + d1*grad2)%R with (d1 * (grad1 + grad2))%R in IHdf1 by lra. - rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - replace (d0 *grad1 + d0 *grad2)%R with (d0 *(grad1 + grad2))%R in IHdf2 by lra. - rewrite eqq3,eqq5,eqq7 in IHdf2. - unfold lift, lift2 in IHdf2; inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 d val1); trivial. - rewrite (split_subvar env1 env2 val val2); trivial. - rewrite (split_subvar env3 env4 val0 val3); trivial. - rewrite H7, H8; lra. - - Case "Divide"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 (grad1/d1)%R (grad2/d1)%R grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 (-d0/(d1*d1) * grad1)%R (-d0/(d1*d1) * grad2)%R env1 env3 env). - rewrite eqq, eqq8, eqq9 in IHdf1. - simpl in IHdf1. - replace (grad1/d1 + grad2/d1)%R with ((grad1 + grad2)/d1)%R in IHdf1 by lra. - rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - replace (-d0/(d1*d1) *grad1 + -d0/(d1*d1) *grad2)%R with (-d0/(d1*d1) *(grad1 + grad2))%R in IHdf2 by lra. - rewrite eqq3, eqq5, eqq7 in IHdf2. - unfold lift, lift2 in IHdf2; inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 d val1); trivial. - rewrite (split_subvar env1 env2 val val2); trivial. - rewrite (split_subvar env3 env4 val0 val3); trivial. - rewrite H7, H8; lra. - - Case "Square"%string. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df); intros. - specialize (H3 H2). - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (2 * d0 *grad1)%R (2 * d0 *grad2)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq1, eqq2 in IHdf. - simpl in IHdf. - replace (2 * d0 * grad1 + 2 * d0 * grad2)%R with (2 * d0 * (grad1 + grad2))%R in IHdf by lra. - apply IHdf. - - Case "Exp"%string. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df); intros. - specialize (H3 H2). - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (grad1 * exp d0)%R (grad2*exp d0)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq1, eqq2 in IHdf. - simpl in IHdf. - replace (grad1 * exp d0 + grad2 * exp d0)%R with ((grad1 + grad2) * exp d0)%R in IHdf by lra. - apply IHdf. - - Case "Log"%string. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df); intros. - specialize (H3 H2). - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (grad1/d0)%R (grad2/d0)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq1, eqq2 in IHdf. - simpl in IHdf. - replace (grad1/d0 + grad2/d0)%R with ((grad1 + grad2) / d0)%R in IHdf by lra. - apply IHdf. - - Case "Abs"%string. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df); intros. - specialize (H3 H2). - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (grad1*sign d0)%R (grad2*sign d0)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq1, eqq2 in IHdf. - simpl in IHdf. - replace (grad1*sign d0 + grad2*sign d0)%R with ((grad1 + grad2) * sign d0)%R in IHdf by lra. - apply IHdf. - - Case "Sign"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (0)%R (0)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - replace (0 + 0)%R with 0%R in IHdf by lra. - apply IHdf. - - Case "PSign"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (0)%R (0)%R grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - replace (0 + 0)%R with 0%R in IHdf by lra. - apply IHdf. - - Case "Max"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - destruct (Rle_dec d0 d1). - + match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf2 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H3). - rewrite eqq, eqq2, eqq3 in IHdf2; simpl in IHdf2. - apply IHdf2. - + match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - rewrite eqq, eqq2, eqq3 in IHdf1; simpl in IHdf1. - apply IHdf1. - - Case "VectorDot"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 - (vmap (fun rv : R => (rv * grad1)%R) d1) - (vmap (fun rv : R => (rv * grad2)%R) d1) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 - (vmap (fun lv : R => (lv * grad1)%R) d0) - (vmap (fun lv : R => (lv * grad2)%R) d0) - env1 env3 env). - rewrite eqq, eqq8, eqq9 in IHdf1. - simpl in IHdf1. - replace - (fun i : {n' : nat | n' < n} => - (vmap (fun rv : R => rv * grad1) d1 i + vmap (fun rv : R => rv * grad2) d1 i)%R) - with - (vmap (fun rv : R => (rv * (grad1 + grad2))%R) d1) - in IHdf1. - + rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; inversion IHdf1. - simpl in IHdf2. - cut_to IHdf2; simpler2. - replace (fun i : {n' : nat | n' < n} => - (vmap (fun lv : R => lv * grad1) d0 i + - vmap (fun lv : R => lv * grad2) d0 i)%R) with - (vmap (fun lv : R => (lv * (grad1 + grad2))%R) d0) in IHdf2. - * rewrite eqq3, eqq5, eqq7 in IHdf2. - unfold lift, lift2 in IHdf2; inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 d val1); trivial. - rewrite (split_subvar env1 env2 val val2); trivial. - rewrite (split_subvar env3 env4 val0 val3); trivial. - rewrite H7, H8; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite vmap_nth. - lra. - + apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite vmap_nth. - lra. - - Case "VectorSum"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (ConstVector n grad1) (ConstVector n grad2) grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - unfold ConstVector in IHdf. - unfold ConstVector. - apply IHdf. - - Case "MatrixSum"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (ConstMatrix m n grad1) (ConstMatrix m n grad2) grad_env1 grad_env2 grad_env3). - specialize (IHdf H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - unfold ConstMatrix in IHdf. - unfold ConstMatrix. - apply IHdf. - - Case "VectorElem"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf (fun k : {n' : nat | n' < n} => - if equiv_dec (` k) (` i) then grad1 else 0%R) - (fun k : {n' : nat | n' < n} => - if equiv_dec (` k) (` i) then grad2 else 0%R) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - replace (fun k : {n' : nat | n' < n} => - if equiv_dec (` k) (` i) then (grad1 + grad2)%R else 0%R) with - (fun i0 : {n' : nat | n' < n} => - ((if equiv_dec (` i0) (` i) then grad1 else 0) + - (if equiv_dec (` i0) (` i) then grad2 else 0))%R). - apply IHdf. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)); lra. - - Case "MatrixElem"%string. - match_option; [|tauto]. - match_option; [|tauto]. - match_option; [|tauto]. - specialize (IHdf - (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) - then if equiv_dec (` k2) (` j) then grad1 else 0%R - else 0%R) - (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) - then if equiv_dec (` k2) (` j) then grad2 else 0%R - else 0%R) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - rewrite eqq, eqq0, eqq1 in IHdf. - simpl in IHdf. - replace (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) - then if equiv_dec (` k2) (` j) then (grad1 + grad2)%R else 0%R - else 0%R) with - (fun (i0 : {n' : nat | n' < m}) (j0 : {m' : nat | m' < n}) => - ((if equiv_dec (` i0) (` i) - then if equiv_dec (` j0) (` j) then grad1 else 0 - else 0) + - (if equiv_dec (` i0) (` i) - then if equiv_dec (` j0) (` j) then grad2 else 0 - else 0))%R). - apply IHdf. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)). - + apply FunctionalExtensionality.functional_extensionality; intros. - destruct (equiv_dec (` x0) (` j)); lra. - + apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "MatrixVectorMult"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 - (fun (i : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (grad1 i * d1 j)%R) - (fun (i : {n' : nat | n' < m}) - (j : {m' : nat | m' < n}) => (grad2 i * d1 j)%R) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - rewrite eqq, eqq8, eqq9 in IHdf1. - specialize (IHdf2 - (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d0 j i) grad1) - (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => d0 j i) grad2) - env1 env3 env). - simpl in IHdf1. - replace - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (grad1 i * d1 j + grad2 i * d1 j)%R) with - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - ((grad1 i + grad2 i) * d1 j)%R) in IHdf1. - + rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - replace - (fun i : {n' : nat | n' < n} => - (matrix_vector_mult - (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d0 j i0) grad1 i + - matrix_vector_mult - (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d0 j i0) grad2 i)%R) with - (matrix_vector_mult - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => d0 j i) - (fun i : {n' : nat | n' < m} => (grad1 i + grad2 i)%R)) in IHdf2. - * rewrite eqq3, eqq5, eqq7 in IHdf2. - unfold lift, lift2 in IHdf2; inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 d val1); trivial. - rewrite (split_subvar env1 env2 val val2); trivial. - rewrite (split_subvar env3 env4 val0 val3); trivial. - rewrite H7, H8; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - unfold matrix_vector_mult. - rewrite vsum_plus; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - + apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "MatrixVectorAdd"%string. - destruct H2. - simpl; intros. - repeat simpl_closed_backprop. - simpler2. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - simpl in IHdf1. - rewrite eqq,eqq0,eqq1,eqq2,eqq3,eqq4 in IHdf1. - unfold lift, lift2 in IHdf1. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H); intro. - case_eq (vartlookup env0 (s, DTfloat)); [intros|tauto]. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat) H0); intro. - case_eq (vartlookup env1 (s, DTfloat)); [intros|tauto]. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq (s, DTfloat) H1); intro. - case_eq (vartlookup env (s, DTfloat)); [intros|tauto]. - generalize (backprop_mat_grad_sum_list_env_iter - σ df2 s env0 env1 env (transpose grad1) (transpose grad2) - d d0 d1 (bounded_seq0 n)); intros. - simpl in H10. - specialize (H10 H3). - cut_to H10; trivial. - + do 3 match_option. - * rewrite eqq6, eqq7 in H10. - replace (fun (j : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env - (fun k : {n' : nat | n' < m} => ((@transpose R m n grad1 j k) + - (@transpose R m n grad2 j k))%R)) - with - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env - (transpose - (fun (i0 : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (grad1 i0 j + grad2 i0 j)%R) i)) in H10. - -- rewrite eqq5 in H10. - unfold lift, lift2 in H10. - unfold lift, lift2; f_equal. - inversion IHdf1; inversion H10. - rewrite (split_subvar env d2 val d1); trivial. - rewrite (split_subvar env0 d3 val0 d); trivial. - rewrite (split_subvar env1 d4 val1 d0); trivial. - rewrite H12, H13; lra. - -- apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - f_equal. - * assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad2 i)) - (Some env1) (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - * assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad1 i)) - (Some env0) (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - * assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env (transpose grad1 i)) - (Some env0) (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - * assert (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv σ df2 env - (transpose - (fun (i0 : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (grad1 i0 j + grad2 i0 j)%R) i)) (Some env) (bounded_seq0 n) <> None). - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - + intros. - apply IHdf2; trivial. - - Case "MatrixMult"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - unfold lift, lift2. - repeat simpl_closed_backprop. - simpler2. - specialize - (IHdf1 - (matrix_mult grad1 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i)) - (matrix_mult grad2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i)) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - rewrite eqq, eqq8, eqq9 in IHdf1. - specialize (IHdf2 - (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d0 j i) - grad1) - (matrix_mult (fun (i : {n' : nat | n' < p}) - (j : {m' : nat | m' < m}) => d0 j i) - grad2) - env1 env3 env). - simpl in IHdf1. - replace - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < p}) => - (matrix_mult grad1 - (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) => - d1 j0 i0) i j + - matrix_mult grad2 - (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) => - d1 j0 i0) i j)%R) - with - (matrix_mult - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (grad1 i j + grad2 i j)%R) - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i)) in IHdf1. - + rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - cut_to IHdf2; simpler2. - simpl in IHdf2. - rewrite eqq5, eqq7 in IHdf2. - replace - (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < n}) => - (matrix_mult - (fun (i0 : {n' : nat | (n' < p)%nat}) (j0 : {m' : nat | (m' < m)%nat}) => - d0 j0 i0) grad1 i j + - matrix_mult - (fun (i0 : {n' : nat | (n' < p)%nat}) (j0 : {m' : nat | (m' < m)%nat}) => - d0 j0 i0) grad2 i j)%R) with - (matrix_mult (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < m}) => d0 j i) - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (grad1 i j + grad2 i j)%R)) - in IHdf2. - * rewrite eqq3 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 d val1); trivial. - rewrite (split_subvar env1 env2 val val2); trivial. - rewrite (split_subvar env3 env4 val0 val3); trivial. - rewrite H7, H8; lra. - * unfold matrix_mult. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vsum_plus; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - + unfold matrix_mult. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vsum_plus; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl; lra. - - Case "VectorPlus"%string. - destruct H2. - unfold lift, lift2. - repeat simpl_closed_backprop. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 grad1 grad2 env1 env3 env). - simpler2. - simpl in IHdf1. - rewrite eqq1, eqq3, eqq in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - cut_to IHdf2; try congruence. - simpl in IHdf2. - rewrite eqq0, eqq2, eqq4 in IHdf2. - unfold lift, lift2 in IHdf2; inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 val val5); trivial. - rewrite (split_subvar env1 env2 val0 val6); trivial. - rewrite (split_subvar env3 env4 val1 val7); trivial. - rewrite eqq5 in eqq8. - rewrite eqq6 in eqq9. - rewrite eqq7 in eqq10. - invcs eqq8; invcs eqq9; invcs eqq10. - rewrite H5, H6; lra. - - Case "VectorMinus"%string. - destruct H2. - unfold lift, lift2. - repeat simpl_closed_backprop. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad1 i)%R) - (fun i : {n' : nat | n' < n} => (- grad2 i)%R) - env1 env3 env). - simpler2. - simpl in IHdf1. - rewrite eqq, eqq1, eqq3 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - cut_to IHdf2; try congruence. - simpl in IHdf2. - replace (fun i : {n' : nat | n' < n} => (- grad1 i + - grad2 i)%R) with - (fun i : {n' : nat | n' < n} => (- (grad1 i + grad2 i))%R) in IHdf2. - rewrite eqq0, eqq2, eqq4 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar env env0 val val5); trivial. - rewrite (split_subvar env1 env2 val0 val6); trivial. - rewrite (split_subvar env3 env4 val1 val7); trivial. - rewrite eqq5 in eqq8; rewrite eqq6 in eqq9; rewrite eqq7 in eqq10. - invcs eqq8; invcs eqq9; invcs eqq10. - rewrite H5, H6; lra. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "MatrixPlus"%string. - destruct H2. - match_option; [|tauto]. - assert (df_eval_backprop_deriv - σ df1 grad_env3 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 (grad1)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 (grad2)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 grad1 grad2 d2 d4 d0). - rewrite eqq, eqq1, eqq3 in IHdf1. - simpl in IHdf1. - rewrite eqq0, eqq2, eqq4 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H0). - assert (vartlookup d0 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H1). - specialize (IHdf2 H7 H9 H10 H3). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - simpl in IHdf2. - unfold lift, lift2; simpl. - assert (df_eval_backprop_deriv - σ df2 d0 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R) <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d2 (grad1)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d2 (grad1)%R); [intros|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 (grad2)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d4 (grad2)%R); [intros|tauto]. - simpl in IHdf2. - rewrite eqq8, H13, H15 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar d0 d8 d d5); trivial. - rewrite (split_subvar d2 d9 d1 d6); trivial. - rewrite (split_subvar d4 d10 d3 d7); trivial. - rewrite H8, H17; lra. - - Case "MatrixMinus"%string. - destruct H2. - match_option; [|tauto]. - assert (df_eval_backprop_deriv - σ df1 grad_env3 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R) <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 (grad1)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 (grad2)%R <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (- grad1 i j)%R) - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (- grad2 i j)%R) - d2 d4 d0). - rewrite eqq, eqq1, eqq3 in IHdf1. - simpl in IHdf1. - rewrite eqq0, eqq2, eqq4 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H0). - assert (vartlookup d0 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H1). - specialize (IHdf2 H7 H9 H10 H3). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift, lift2; simpl. - assert (df_eval_backprop_deriv - σ df2 d0 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (- (grad1 i j + grad2 i j))%R) <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - match_option; [|tauto]. - assert (df_eval_backprop_deriv - σ df2 d2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad1 i j)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad1 i j)%R)) - ; [intros |tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad2 i j)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d4 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad2 i j)%R)) - ; [intros | tauto]. - simpl in IHdf2. - replace - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (- grad1 i j + - grad2 i j)%R) with - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (- (grad1 i j + grad2 i j))%R) in IHdf2. - rewrite eqq8, H13, H15 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar d0 d8 d d5); trivial. - rewrite (split_subvar d2 d9 d1 d6); trivial. - rewrite (split_subvar d4 d10 d3 d7); trivial. - rewrite H8, H17; lra. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "VectorScalMult"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env3 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * (grad1 j + grad2 j))%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad1 j)%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad2 j)%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - specialize (IHdf1 - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad1 j)%R)) - (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad2 j)%R)) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 - (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R) - (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R) - d4 d6 d2). - rewrite eqq, eqq3, eqq5 in IHdf1. - simpl in IHdf1. - replace - (@vsum floatish_R n (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad1 j) + - @vsum floatish_R n (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad2 j))%R with - (vsum (fun j : {n' : nat | n' < n} => (d1 j * (grad1 j + grad2 j))%R)) in IHdf1. - + rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H1). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H). - assert (vartlookup d6 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq6 (s, DTfloat) H0). - specialize (IHdf2 H11 H12 H9 H3). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift, lift2; simpl. - assert (df_eval_backprop_deriv - σ df2 d2 - (fun j : {n' : nat | n' < n} => (d0 * (grad1 j + grad2 j))%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d4 - (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R)) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ df2 d6 - (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R) - <> None ). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d6 - (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R)) - ; [intros | tauto]. - simpl in IHdf2. - replace - (fun i : {n' : nat | n' < n} => (d0 * grad1 i + d0 * grad2 i)%R) with - (fun j : {n' : nat | n' < n} => (d0 * (grad1 j + grad2 j))%R) in IHdf2. - * rewrite eqq10, H15, H17 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar d2 d10 d d7); trivial. - rewrite (split_subvar d4 d11 d3 d8); trivial. - rewrite (split_subvar d6 d12 d5 d9); trivial. - rewrite H10, H19; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + simpl. - rewrite vsum_plus. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "MatrixScalMult"%string. - destruct H2. - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df1); intros. - specialize (H4 H2). - match_option; [|tauto]. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H5 H3). - match_option; [|tauto]. - assert (df_eval_backprop_deriv - σ df1 grad_env3 - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * (grad1 i j + grad2 i j))%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env1 - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * grad1 i j)%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df1 grad_env2 - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * grad2 i j)%R)) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H2. - match_option; [|tauto]. - specialize (IHdf1 - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * grad1 i j)%R)) - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * grad2 i j)%R)) - grad_env1 grad_env2 grad_env3 H H0 H1 H2). - specialize (IHdf2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (grad1 i j * d0)%R) - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (grad2 i j * d0)%R) - d4 d6 d2). - rewrite eqq, eqq3, eqq5 in IHdf1. - simpl in IHdf1. - replace - (@msum floatish_R n m - (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d1 i j * grad1 i j) + - @msum floatish_R n m - (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) => - d1 i j * grad2 i j))%R with - (msum - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d1 i j * (grad1 i j + grad2 i j))%R)) in IHdf1. - + rewrite eqq2, eqq4, eqq6 in IHdf1. - unfold lift, lift2 in IHdf1; simpl in IHdf1. - inversion IHdf1. - assert (vartlookup d2 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H1). - assert (vartlookup d4 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H). - assert (vartlookup d6 (s, DTfloat) <> None). - apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq6 (s, DTfloat) H0). - specialize (IHdf2 H11 H12 H9 H3). - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - match_option_in IHdf2; [|tauto]. - unfold lift, lift2; simpl. - assert (df_eval_backprop_deriv - σ df2 d2 - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - ((grad1 i j + grad2 i j) * d0)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - match_option; [|tauto]. - assert (df_eval_backprop_deriv σ df2 d4 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad1 i j * d0)%R) - <> None). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d4 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad1 i j * d0)%R)) - ; [intros | tauto]. - assert (df_eval_backprop_deriv σ df2 d6 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad2 i j * d0)%R) - <> None ). - apply backprop_deriv_fully_closed_not_none. - apply H3. - case_eq (df_eval_backprop_deriv σ df2 d6 - (fun (i : {n' : nat | n' < n}) - (j : {m' : nat | m' < m}) => (grad2 i j * d0)%R)) - ; [intros | tauto]. - simpl in IHdf2. - replace - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (grad1 i j * d0 + grad2 i j * d0)%R) with - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - ((grad1 i j + grad2 i j) * d0)%R) in IHdf2. - * rewrite eqq10, H15, H17 in IHdf2. - unfold lift, lift2 in IHdf2; simpl in IHdf2. - inversion IHdf2. - f_equal. - rewrite (split_subvar d2 d10 d d7); trivial. - rewrite (split_subvar d4 d11 d3 d8); trivial. - rewrite (split_subvar d6 d12 d5 d9); trivial. - rewrite H10, H19; lra. - * apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - + unfold msum. - rewrite vsum_plus. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite vsum_plus. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - - Case "VectorApply"%string. - destruct H2. - simpler2. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H4 H3). - match_option. - match_option. - + match_option. - * match_option. - -- unfold lift. - repeat simpl_closed_backprop. - unfold lift2. - specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros. - specialize (IHdf2 v1 v2 grad_env1 grad_env2 grad_env3). - cut_to IHdf2; try congruence. - rewrite eqq, eqq0, eqq1, eqq7, eqq8 in IHdf2; simpl in IHdf2. - replace (fun i : {n' : nat | n' < n} => (v1 i + v2 i)%R) with v0 in IHdf2. - rewrite eqq6 in IHdf2. - unfold lift, lift2 in IHdf2. - apply IHdf2. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H5 x);rewrite vmap_nth in H5; simpl in H5. - specialize (H6 x);rewrite vmap_nth in H6; simpl in H6. - specialize (H7 x);rewrite vmap_nth in H7; simpl in H7. - match_option_in H5; invcs H5. - match_option_in H6; invcs H6. - match_option_in H7; invcs H7. - rewrite eqq9 in eqq10; invcs eqq10. - rewrite eqq9 in eqq11; invcs eqq11. - lra. - -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - * specialize (apply vectoro_to_ovector_exists_None eqq4); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - + specialize (apply vectoro_to_ovector_exists_None eqq3); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - - Case "MatrixApply"%string. - destruct H2. - simpler2. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H4 H3). - match_option. - match_option. - + match_option. - * match_option. - -- unfold lift. - repeat simpl_closed_backprop. - unfold lift2. - unfold matrixo_to_omatrix in *. - specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros. - specialize (IHdf2 m1 m2 grad_env1 grad_env2 grad_env3). - cut_to IHdf2; try congruence. - rewrite eqq, eqq0, eqq1, eqq7, eqq8 in IHdf2; simpl in IHdf2. - replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (m1 i j + m2 i j)%R) with m0 in IHdf2. - rewrite eqq6 in IHdf2. - unfold lift, lift2 in IHdf2. - apply IHdf2. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H5 x); specialize (H6 x); specialize (H7 x). - unfold mmap in H5; unfold mmap in H6; unfold mmap in H7. - specialize (apply vectoro_to_ovector_forall_some_f H5);intros. - specialize (apply vectoro_to_ovector_forall_some_f H6);intros. - specialize (apply vectoro_to_ovector_forall_some_f H7);intros. - specialize (H8 x0); do 2 rewrite vmap_nth in H8. - specialize (H9 x0); do 2 rewrite vmap_nth in H9. - specialize (H10 x0); do 2 rewrite vmap_nth in H10. - rewrite matrix_zip_m_n in H8. - rewrite matrix_zip_m_n in H9. - rewrite matrix_zip_m_n in H10. - match_option_in H8. - rewrite eqq9 in H9; rewrite eqq9 in H10. - invcs H8; invcs H9; invcs H10. - lra. - -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros; destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5. - unfold mmap in e0. - do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - * specialize (apply vectoro_to_ovector_exists_None eqq4); intros; destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5. - unfold mmap in e0. - do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - + specialize (apply vectoro_to_ovector_exists_None eqq3); intros; destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5. - unfold mmap in e0. - do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - - Case "VLossfun"%string. - destruct H2. - simpler2. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H4 H3). - match_option. - match_option. - + match_option. - * match_option. - -- unfold lift. - repeat simpl_closed_backprop. - unfold lift2. - specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros. - specialize (IHdf2 v0 v3 grad_env1 grad_env2 grad_env3). - cut_to IHdf2; try congruence. - rewrite eqq, eqq0, eqq1 in IHdf2. - rewrite eqq7, eqq8 in IHdf2. - simpl in IHdf2. - replace (fun i : {n' : nat | n' < n} => (v0 i + v3 i)%R) with v in IHdf2. - rewrite eqq6 in IHdf2. - unfold lift, lift2 in IHdf2. - apply IHdf2. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H5 x);rewrite vmap_nth in H5; simpl in H5. - specialize (H6 x);rewrite vmap_nth in H6; simpl in H6. - specialize (H7 x);rewrite vmap_nth in H7; simpl in H7. - match_option_in H5. - rewrite eqq9 in H6; rewrite eqq9 in H7. - invcs H5; invcs H6; invcs H7. - lra. - -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert ( df_eval [mk_env_entry (v1, DTfloat) (d x); - mk_env_entry (v2, DTfloat) (r x)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - * specialize (apply vectoro_to_ovector_exists_None eqq4); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x); - mk_env_entry (v2, DTfloat) (r x)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - + specialize (apply vectoro_to_ovector_exists_None eqq3); intros. - destruct H5. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x); mk_env_entry (v2, DTfloat) (r x)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - - Case "MLossfun"%string. - destruct H2. - simpler2. - generalize (eval_fully_closed_not_none σ df2); intros. - specialize (H4 H3). - match_option. - match_option. - + match_option. - * match_option. - -- unfold lift. - repeat simpl_closed_backprop. - unfold lift2. - specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros. - specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros. - specialize (IHdf2 m1 m2 grad_env1 grad_env2 grad_env3). - cut_to IHdf2; try congruence. - rewrite eqq, eqq0, eqq1 in IHdf2. - rewrite eqq7, eqq8 in IHdf2. - simpl in IHdf2. - replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (m1 i j + m2 i j)%R) with m0 in IHdf2. - rewrite eqq6 in IHdf2. - unfold lift, lift2 in IHdf2. - apply IHdf2. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H5 x); simpl in H5. - specialize (H6 x); simpl in H6. - specialize (H7 x); simpl in H7. - specialize (apply vectoro_to_ovector_forall_some_f H5);intros. - specialize (apply vectoro_to_ovector_forall_some_f H6);intros. - specialize (apply vectoro_to_ovector_forall_some_f H7);intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H8 x0); unfold mmap in H8;do 2 rewrite vmap_nth in H8; simpl in H8. - specialize (H9 x0); unfold mmap in H9;do 2 rewrite vmap_nth in H9; simpl in H9. - specialize (H10 x0); unfold mmap in H10;do 2 rewrite vmap_nth in H10; simpl in H10. - rewrite matrix_zip_m_n in H8. - rewrite matrix_zip_m_n in H9. - rewrite matrix_zip_m_n in H10. - match_option_in H8. - rewrite eqq9 in H9; rewrite eqq9 in H10. - invcs H8; invcs H9; invcs H10. - lra. - -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros. - destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H5. - unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert ( df_eval [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - * specialize (apply vectoro_to_ovector_exists_None eqq4); intros. - destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H5. - unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - + specialize (apply vectoro_to_ovector_exists_None eqq3); intros. - destruct H5. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H5. - unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0. - rewrite matrix_zip_m_n in e0. - match_option_in e0. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] - (df_deriv df1 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none. - now apply fully_closed_deriv. - tauto. - Qed. - - Lemma vectoro_to_ovector_eNone_None {A n} {vo:Vector (option A) n} : - {i | vo i = None} -> - vectoro_to_ovector vo = None. - Proof. - intros. - destruct H. - now apply (vectoro_to_ovector_None_None x). - Qed. - - Definition ConstSplitVector {T} (middle theend:nat) (part1 part2:T) : (Vector T theend) := - fun (i: {n':nat | n' < theend}%nat) => - if lt_dec (proj1_sig i) middle then part1 else part2. - - Definition mergeVectorZero {T} {n} (middle:nat) (part1 : Vector T n) (c:T) : (Vector T n) := - fun (i: {n':nat | n' < n}%nat) => - if lt_dec (proj1_sig i) middle then (part1 i) else c. - - Definition scaleUnitVector {T} (n:nat) (j : {n':nat | (n' < n)%nat}) (c:T) (zero:T) : Vector T n := - fun i => if (proj1_sig i) == (proj1_sig j) then c else zero%R. - - Lemma ConstSplitVectorSzero bound n (pf:bound < n) : - (ConstSplitVector (S bound) n 1%R 0%R) = - dfti_gen_plus (T:=DTVector n) (ConstSplitVector bound n 1%R 0%R) (UnitVector n (exist _ bound pf)). - Proof. - simpl. - apply functional_extensionality. - intros. - unfold ConstSplitVector, UnitVector, equiv_dec, nat_eq_eqdec; simpl. - destruct x as [x pff]; simpl. - destruct (lt_dec x (S bound)) - ; destruct (lt_dec x bound) - ; destruct (Nat.eq_dec x bound) - ; try lia; try lra. - Qed. - - Lemma mergeVectorSzero {n} (bound:nat) (pf:bound < n) (part1 : Vector R n): - let ind := (exist _ bound pf) in - (@mergeVectorZero (@float floatish_R) n (S bound) part1 0%R) = - dfti_gen_plus (T:=DTVector n) (@mergeVectorZero (@float floatish_R) n bound part1 0%R) - (scaleUnitVector n ind (part1 ind) 0%R). - Proof. - simpl. - apply functional_extensionality. - intros. - unfold mergeVectorZero, scaleUnitVector, equiv_dec, nat_eq_eqdec; simpl. - destruct x as [x pff]; simpl. - destruct (lt_dec x (S bound)) - ; destruct (lt_dec x bound) - ; destruct (Nat.eq_dec x bound) - ; try lia; try lra. - subst; simpl. - ring_simplify. - erewrite index_pf_irrel; eauto. - Qed. - - Lemma mergeVectorSzero_mat {n m} (bound:nat) pf (part1 : Matrix float n m) : - let ind := (exist _ bound pf) in - (@mergeVectorZero (Vector float m) n (S bound) part1 (ConstVector m 0%R)) = - dfti_gen_plus (T:=DTMatrix n m) (@mergeVectorZero (Vector float m) n bound part1 (ConstVector m 0%R)) - (scaleUnitVector (T:=Vector float m) n ind (part1 ind) (ConstVector m 0%R)). - Proof. - simpl. - do 2 (apply functional_extensionality; intros). - unfold mergeVectorZero, scaleUnitVector, ConstVector, equiv_dec, nat_eq_eqdec; simpl. - destruct x as [x pff]; simpl. - destruct (lt_dec x (S bound)) - ; destruct (lt_dec x bound) - ; destruct (Nat.eq_dec x bound) - ; simpl - ; try lia; try lra. - subst; simpl. - rewrite Rplus_0_l. - erewrite index_pf_irrel; eauto. - Qed. - - Lemma vsum_alt_eq {m:nat} (v:Vector R m) : vsum v = vector_fold_right Fplus 0%R v. - Proof. - apply vector_fold_right1_as_vector_fold_right. - unfold Datatypes.id; simpl; intros; lra. - Qed. - - Lemma vsum_cons {m:nat} x (v:Vector R m) : - vsum (vcons x v) = (x + vsum v)%R. - Proof. - repeat rewrite vsum_alt_eq. - apply vector_fold_right_vcons. - Qed. - - Lemma constSplitVectorZero {n} : - ConstSplitVector 0 n 1%R 0%R = ConstVector n 0%R. - Proof. - unfold ConstSplitVector, ConstVector; simpl. - now apply functional_extensionality. - Qed. - - Lemma df_eval_backprop_delta_by_unit_partvec_bounded {n} bound (pf:(bound<=n)%nat) - (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (grad d:Vector float n): - - fully_closed_over df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall i : {n' : nat | n' < n}, - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (scaleUnitVector n i (grad i) 0%R)) = Some (d i)) -> - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (mergeVectorZero bound grad 0%R)) = Some (vsum (vfirstn d bound pf)). - Proof. - intros closed lo fa. - induction bound. - - replace (mergeVectorZero 0 grad 0%R) with (scalarMult (DTVector n) 0%R (mergeVectorZero 0 grad 0%R)). - + erewrite scalarMult_backprop_grad_scalar; try eassumption. - * simpl. - unfold df_eval_backprop_delta. - simpler2. - unfold lift. - simpl_closed_backprop. - f_equal. - vm_compute; lra. - * apply backprop_deriv_fully_closed_not_none; trivial. - * apply backprop_deriv_fully_closed_not_none; trivial. - + unfold scalarMult, mergeVectorZero. - apply functional_extensionality; intros; simpl. - lra. - - assert (pf2:bound < n) by lia. - assert (pf3:bound <= n) by lia. - rewrite (mergeVectorSzero _ pf2 _ ). - erewrite backprop_grad_sum; try eassumption. - specialize (IHbound pf3). - rewrite IHbound. - rewrite fa. - simpl. - f_equal. - destruct n; [lia | ]. - generalize (vector_Sn_split (vfirstn d (S bound) pf)); intros eqq1. - apply vec_eq_eq in eqq1. - simpl in *. - rewrite eqq1. - erewrite vfirstn_vdrop_last. - erewrite vlast_vfirstn. - rewrite vsum_cons. - rewrite Rplus_comm. - f_equal. - Qed. - - Lemma df_eval_backprop_delta_by_unit_partvec {n} - (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (grad d:Vector float n): - - fully_closed_over df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall i : {n' : nat | n' < n}, - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (scaleUnitVector n i (grad i) 0%R)) = Some (d i)) -> - df_eval_backprop_delta σ df (s, DTfloat) grad_env grad = - Some (vsum d). - Proof. - intros. - replace (grad) with (mergeVectorZero n grad 0%R). - - erewrite df_eval_backprop_delta_by_unit_partvec_bounded; try eassumption. - now rewrite vfirstn_eq. - - apply functional_extensionality; unfold mergeVectorZero; intros [x pff]; simpl. - destruct (lt_dec x n); trivial; lia. - Unshelve. - lia. - Qed. - - Lemma df_eval_backprop_delta_by_unit_parts {n} - (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (d:Vector float n): - - fully_closed_over df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall i : {n' : nat | n' < n}, - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - - (UnitVector n i)) = Some (d i)) -> - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (ConstVector n 1%R)) = Some (vsum d). - Proof. - intros. - replace (ConstVector n 1%R) with (mergeVectorZero n (ConstVector n 1%R) 0%R). - - erewrite df_eval_backprop_delta_by_unit_partvec_bounded; try eassumption. - now rewrite vfirstn_eq. - - apply functional_extensionality; unfold mergeVectorZero, ConstVector; intros [x pff]; simpl. - destruct (lt_dec x n); trivial; lia. - Unshelve. - lia. - Qed. - - Lemma scalarMult_mult {T} a b grad : scalarMult T a (scalarMult T b grad) = scalarMult T (a*b)%R grad. - Proof. - destruct T; simpl. - - lra. - - apply vec_eq_eq; intros ?; lra. - - do 2 (apply vec_eq_eq; intros ?); lra. - Qed. - - Corollary scalarMult_backprop_grad0 {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (s: SubVar) (grad_env :df_env) (grad : definition_function_types_interp T) : - let v := (s, DTfloat) in - vartlookup grad_env v <> None -> - df_eval_backprop_deriv σ df grad_env (scalarMult T 0%R grad) <> None -> - df_eval_backprop_delta σ df v grad_env (scalarMult T 0%R grad) = Some 0%R. - Proof. - simpl; intros. - (* This allows us to drop the (unneeded) assumption - df_eval_backprop_deriv σ df grad_env1 grad <> None - *) - replace (scalarMult T 0%R grad) with (scalarMult T 0%R (scalarMult T 0%R grad)). - - erewrite scalarMult_backprop_grad_scalar; try eassumption. - + simpl. - unfold df_eval_backprop_delta. - simpler2; simpl. - destruct (df_eval_backprop_deriv σ df grad_env (scalarMult T 0%R grad)); simpl; [| congruence]. - f_equal; lra. - + now rewrite scalarMult_mult, Rmult_0_l. - - now rewrite scalarMult_mult, Rmult_0_l. - Qed. - - Lemma scaleUnitVec_vec_plus_distr m n i (x y:Vector float m) c : - vec_eq c (dfti_gen_plus (T:=DTVector m) c c) -> - (scaleUnitVector n i (dfti_gen_plus (T:=DTVector m) x y) c) = - (dfti_gen_plus (T:=DTMatrix n m) (scaleUnitVector n i x c) (scaleUnitVector n i y c)). - Proof. - intros fa. - do 2 (apply functional_extensionality; intro). - unfold scaleUnitVector; simpl. - match_destr. - now specialize (fa x1); simpl in fa. - Qed. - - - Lemma df_eval_backprop_delta_by_unit_partmat_outer_bounded {n m} bound (pf:(bound<=n)%nat) - (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env (grad d:Matrix float n m): - - - fully_closed_over df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall (i : {n' : nat | n' < n}), - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (scaleUnitVector n i (grad i) (ConstVector m 0%R))) = Some (vsum (d i))) -> - - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (mergeVectorZero bound grad (ConstVector m 0%R))) = Some (vsum (vfirstn (vmap vsum d) bound pf)). - Proof. - intros closed lo fa. - induction bound. - - rewrite vfirstn0, vsum_nil. - - replace (mergeVectorZero 0 grad (ConstVector m 0%R)) with (scalarMult (DTMatrix n m) 0%R (mergeVectorZero 0 grad (ConstVector m 0%R))). - + apply scalarMult_backprop_grad0; simpl in *; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - + unfold scalarMult, ConstVector, mergeVectorZero. - do 2 (apply functional_extensionality; intros); simpl. - lra. - - assert (pf2:bound < n) by lia. - assert (pf3:bound <= n) by lia. - simpl. - generalize (mergeVectorSzero_mat (m:=m) _ pf2 grad ); intros HH; unfold Vector in HH. - simpl float in *. - rewrite HH; clear HH. - erewrite backprop_grad_sum; try eassumption. - specialize (IHbound pf3). - rewrite IHbound. - rewrite fa. - simpl. - f_equal. - generalize (vector_Sn_split (vfirstn (vmap vsum d) (S bound) pf)); intros eqq1. - apply vec_eq_eq in eqq1. - simpl in *. - rewrite eqq1. - erewrite vfirstn_vdrop_last. - erewrite vlast_vfirstn. - rewrite vsum_cons. - rewrite Rplus_comm. - f_equal. - rewrite vmap_nth. - eauto. - Qed. - - Lemma df_eval_backprop_delta_by_unit_partmat_inner_bounded {n m} bound (pf:(bound<=m)%nat) - (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env (grad d:Matrix float n m) i: - - - fully_closed_over df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall j : {n' : nat | n' < m}, - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - ((scaleUnitVector n i - (scaleUnitVector m j (grad i j) 0%R) (ConstVector m 0%R)))) = Some (d i j)) -> - - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (scaleUnitVector n i (mergeVectorZero bound (grad i) 0%R) (ConstVector m 0%R))) = Some (vsum (vfirstn (d i) bound pf)). - Proof. - intros closed lo fa. - induction bound. - - rewrite vfirstn0, vsum_nil. - replace (scaleUnitVector n i (mergeVectorZero 0 (grad i) 0%R) (ConstVector m 0%R)) with (scalarMult (DTMatrix n m) 0%R (mergeVectorZero 0 grad (ConstVector m 0%R))). - + apply scalarMult_backprop_grad0; simpl in *; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - + unfold scalarMult, ConstVector, mergeVectorZero, scaleUnitVector. - do 2 (apply functional_extensionality; intros); simpl. - match_destr; lra. - - assert (pf2:bound < m) by lia. - assert (pf3:bound <= m) by lia. - simpl. - generalize (mergeVectorSzero _ pf2 (grad i) ); intros HH; unfold Vector in HH. - simpl float in *. - rewrite HH; clear HH. - rewrite (scaleUnitVec_vec_plus_distr m n i) by (intros ?; unfold ConstVector; simpl; lra). - erewrite backprop_grad_sum; try eassumption. - specialize (IHbound pf3). - simpl in *. - rewrite IHbound. - simpl. - rewrite fa. - simpl. - f_equal. - generalize (vector_Sn_split (vfirstn (d i) (S bound) pf)); intros eqq1. - apply vec_eq_eq in eqq1. - simpl in *. - rewrite eqq1. - erewrite vfirstn_vdrop_last. - erewrite vlast_vfirstn. - rewrite vsum_cons. - rewrite Rplus_comm. - f_equal. - Qed. - - Lemma df_eval_backprop_delta_by_unit_partmat {n m} - (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env - (grad d:Matrix float n m): - fully_closed_over - df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) - σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - (forall (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) , - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (scalarMult (DTMatrix n m) (grad i j) - (UnitMatrix n m i j))) = - Some (d i j)) -> - df_eval_backprop_delta σ df (s, DTfloat) grad_env grad = - Some (msum d). - Proof. - intros closed lo fa. - replace (grad) with (mergeVectorZero n grad (ConstVector m 0%R)). - - erewrite df_eval_backprop_delta_by_unit_partmat_outer_bounded; try eassumption. - + now rewrite vfirstn_eq. - + intros i. - specialize (fa i). - simpl in fa. - replace (grad i) with (mergeVectorZero m (grad i) 0%R). - 2: { - unfold mergeVectorZero; simpl; apply functional_extensionality; intros; destruct x. - simpl. - match_destr; lia. - } - - erewrite (df_eval_backprop_delta_by_unit_partmat_inner_bounded m (le_refl m) σ df s grad_env) - ; try eassumption. - * now rewrite vfirstn_eq. - * intros j. - specialize (fa j); simpl in *. - rewrite <- fa. - f_equal. - do 2 (apply functional_extensionality; intros). - unfold scaleUnitVector, ConstVector, UnitMatrix; simpl. - repeat match_destr; lra. - - apply functional_extensionality; unfold mergeVectorZero; intros [x pff]; simpl. - destruct (lt_dec x n); trivial; lia. - Unshelve. - lia. - Qed. - - Lemma df_eval_backprop_delta_by_unit_parts_mat {n m} - (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env - (d:Matrix float n m): - fully_closed_over - df - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) - σ) -> - vartlookup grad_env (s, DTfloat) <> None -> - - (forall (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) , - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - - (UnitMatrix n m i j)) = Some (d i j)) -> - (df_eval_backprop_delta σ df (s, DTfloat) grad_env - (ConstMatrix n m 1%R)) = Some (msum d). - Proof. - intros. - apply df_eval_backprop_delta_by_unit_partmat; trivial. - intros. - rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env); trivial. - unfold lift. - rewrite H1. - f_equal. - unfold scalarMult, ConstMatrix; simpl; lra. - apply backprop_deriv_fully_closed_not_none; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - Qed. - - Corollary scalarMult_backprop_list_env_iter_grad0 {T} (σ:df_env) (s: SubVar) (grad_env :df_env) (grad : definition_function_types_interp T) old n x l : - let v := (s, DTfloat) in - (forall j : {n' : nat | n' < n}, - fully_closed_over (x j) - (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ)) -> - vartlookup grad_env v = Some old -> - lift (fun e => subvar v e old) - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv (Ann:=UnitAnn) σ (x i) env (scalarMult T 0%R grad)) (Some grad_env) l) = Some 0%R. - Proof. - simpl; intros. - revert grad_env old H0. - induction l; simpl; intros. - - unfold subvar; simpl. - rewrite H0; f_equal. - lra. - - unfold lift in *. - case_eq (df_eval_backprop_deriv σ (x a) grad_env (scalarMult T 0%R grad)) - ; intros. - + apply IHl. - * generalize (scalarMult_backprop_grad0 σ (x a) s grad_env grad); intros HH. - simpl in HH. - cut_to HH; try congruence. - unfold df_eval_backprop_delta in HH. - rewrite H0 in HH. - rewrite H1 in HH. - simpl in HH. - invcs HH. - unfold subvar in H3; simpl in H3. - simpler2. - rewrite eqq in eqq0; invcs eqq0; subst. - f_equal. - lra. - + eelim backprop_deriv_fully_closed_not_none; eauto. - Qed. - - Lemma list_env_iter_gen_delta0 {n} (s: SubVar) (init_env : df_env) - (f: {n' : nat | n' < n} -> df_env -> option df_env) (l : list {n' : nat | n' < n}): - let v := (s,DTfloat) in - vartlookup init_env v <> None -> - (forall (i : {n' : nat | n' < n}) (env : df_env), - (f i env) <> None /\ - (vartlookup env v <> None -> - match f i env with - | Some xenv => vartlookup xenv v <> None - | _ => True - end ) /\ - (In i l -> vartlookup env v <> None -> - lift2 (fun e val => subvar v e val) (f i env) - (vartlookup env v) = Some 0%R)) -> - lift2 (fun e old => subvar v e old) - (list_env_iter f (Some init_env) l) - (vartlookup init_env v) = Some 0%R. - Proof. - simpl; intros. - revert init_env H H0. - induction l. - - intros. - simpl; f_equal. - unfold subvar; simpl. - match_option; [|tauto]. - f_equal. - lra. - - intros. - assert (H0c := H0). - specialize (H0 a init_env). - destruct H0; destruct H1. - simpl. - case_eq (f a init_env); [intros|tauto]. - specialize (IHl d). - replace (vartlookup init_env (s, DTfloat)) with (vartlookup d (s, DTfloat)). - + apply IHl. - * specialize (H1 H). - now rewrite H3 in H1. - * intros. - specialize (H0c i env). - destruct H0c; destruct H5. - split; trivial. - split; trivial. - intros. - cut_to H6; trivial. - simpl; tauto. - + case_eq (vartlookup init_env (s, DTfloat)); [intros|tauto]. - rewrite H3, H4 in H2; simpl in H2. - cut_to H2; try tauto. - unfold lift2 in H2; simpl in H2. - inversion H2. - unfold subvar in H6; simpl in H6. - rewrite H3 in H1. - specialize (H1 H). - case_eq ( vartlookup d (s, DTfloat)); [intros|tauto]. - rewrite H5 in H6. - f_equal; lra. - congruence. - Qed. - - Lemma list_env_iter_gen_delta {n} (s: SubVar) (init_env : df_env) (old : float) - (f: {n' : nat | n' < n} -> df_env -> option df_env) (i0 : {n' : nat | n' < n}): - let v := (s,DTfloat) in - vartlookup init_env v = Some old -> - (forall (i : {n' : nat | n' < n}) (env : df_env), - (f i env) <> None /\ - (vartlookup env v <> None -> - match f i env with - | Some xenv => vartlookup xenv v <> None - | _ => True - end ) /\ - (forall (env2 : df_env), - match vartlookup env v, vartlookup env2 v, f i env, f i env2 with - | Some val1, Some val2, Some xenv, Some xenv2 => - subvar v xenv val1 = subvar v xenv2 val2 - | _, _, _, _ => True - end) /\ - ( i <> i0 -> vartlookup env v <> None -> - lift2 (fun e val => subvar v e val) (f i env) - (vartlookup env v) = Some 0%R)) -> - lift (fun e => subvar v e old) (f i0 init_env) = - lift (fun e => subvar v e old) - (list_env_iter f - (Some init_env) (bounded_seq0 n)). - Proof. - simpl; intros. - unfold bounded_seq0. - destruct (bounded_seq_break_at 0 n i0) as [b [c [eqq1 [fa1 fa2]]]]; [lia |]. - rewrite eqq1. - rewrite list_env_iter_app; simpl. - match_option. - - - assert (eqq': subvar (s, DTfloat) d old = 0%R). - + generalize (list_env_iter_gen_delta0 s init_env f b); intros. - simpl in H1. - cut_to H1; try congruence. - * unfold lift2 in H1. - rewrite H, eqq in H1. - now inversion H1. - * intros. - specialize (H0 i env). - destruct H0; destruct H2; destruct H3. - split; trivial. - split; trivial. - intros. - cut_to H4; trivial. - rewrite Forall_forall in fa1. - specialize (fa1 i H5). - intro eq1; rewrite eq1 in fa1; lia. - + generalize (vartlookup_list_env_iter s f b); intros vart. - * specialize (vart init_env d eqq). - assert (vartinit: vartlookup init_env (s, DTfloat) <> None) by congruence. - specialize (vart vartinit). - assert (f i0 d <> None) by apply H0. - case_eq (f i0 d); [intros | tauto]. - generalize (list_env_iter_gen_delta0 s d0 f c); simpl; intros. - cut_to H3; try congruence. - -- unfold lift at 2. - unfold lift2 in H3. - match_option. - ++ rewrite eqq0 in H3. - match_option_in H3. - rewrite (split_subvar d0 d1 old d2);trivial. - inversion H3. - rewrite H5. - unfold lift. - match_option. - ** f_equal. - case_eq (vartlookup d (s, DTfloat)); intros. - --- rewrite (split_subvar d d0 old d4); trivial. - rewrite eqq'. - specialize (H0 i0 init_env). - destruct H0; destruct H6; destruct H7. - specialize (H7 d). - rewrite H, H4, eqq3, H2 in H7. - rewrite H7; lra. - --- cut_to vart; [tauto|]. - intros. - specialize (H0 i env). - destruct H0; destruct H8. - specialize (H8 H7). - now rewrite H6 in H8. - ** specialize (H0 i0 init_env). - destruct H0; congruence. - ++ generalize (list_env_iter_total_fun f d0 c); intros. - cut_to H4; try congruence. - apply H0. - -- cut_to vart. - ++ specialize (H0 i0 d). - destruct H0; destruct H4. - rewrite H2 in H4. - apply H4; trivial. - ++ intros. - specialize (H0 i env). - destruct H0;destruct H6. - specialize (H6 H5). - now rewrite H4 in H6. - -- intros. - specialize (H0 i env). - split. - apply H0. - split. - apply H0. - intros. - assert (i <> i0). - ++ rewrite Forall_forall in fa2. - specialize (fa2 i H4). - intro eq1; rewrite eq1 in fa2; lia. - ++ now apply H0. - - generalize (list_env_iter_total_fun f init_env b); intros. - cut_to H1; trivial. - tauto. - intros. - apply H0. - Qed. - - Lemma list_env_iter_vec_delta {n} (σ:df_env) - (x:Vector (DefinedFunction UnitAnn DTfloat) n) (s: SubVar) grad_env - (i0 : {n' : nat | n' < n}) (old : float) : - let v := (s, DTfloat) in - vartlookup grad_env v = Some old -> - (forall (j: {n' : nat | n' < n}) , - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over (x j) vl) -> - lift (fun e => subvar v e old) - (df_eval_backprop_deriv σ (x i0) grad_env 1%R) = - lift (fun e => subvar v e old) - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i)) - (Some grad_env) (bounded_seq0 n)). - Proof. - simpl; intros. - generalize (list_env_iter_gen_delta - s grad_env old - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i)) - i0). - simpl; intros. - specialize (H1 H). - rewrite <- H1. - - unfold UnitVector; simpl. - now destruct (equiv_dec (` i0) (` i0)); [|congruence]. - - clear H1. - intros. - split; [|split]. - + apply backprop_deriv_fully_closed_not_none; auto. - + intros. - match_option. - now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial. - + split. - * intros. - do 4 match_option. - generalize (backprop_indep_env σ (x i) s env env2 - (UnitVector n i0 i)); simpl; intros HH. - cut_to HH; trivial; try congruence. - unfold df_eval_backprop_delta in HH. - rewrite eqq,eqq0,eqq1,eqq2 in HH. - unfold lift in HH. - now inversion HH. - * intros. - unfold UnitMatrix; simpl. - destruct (equiv_dec (` i) (` i0)). - elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel. - unfold lift2. - generalize (scalarMult_backprop_grad0 σ (x i) s env 0%R); simpl; intros. - unfold df_eval_backprop_delta in H3. - specialize (H3 H2). - replace (0 * 0)%R with 0%R in H3 by lra. - match_option. - match_option; [|tauto]. - -- rewrite eqq0 in H3. - replace (UnitVector n i0 i) with Fzero in eqq; simpl in eqq. - ++ rewrite eqq in H3; unfold lift in H3. - apply H3; congruence. - ++ unfold UnitVector; simpl. - destruct (equiv_dec (` i) (` i0)); [|trivial]. - elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel. - -- assert (df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i) <> None). - now apply backprop_deriv_fully_closed_not_none. - tauto. - Qed. - - - Lemma list_env_iter_backprop_indep_env_vec {m} (σ:df_env) - (vecdf: Vector (DefinedFunction UnitAnn DTfloat) m) - (s:SubVar) (env env2:df_env) (grad: Vector float m) - (old1 old2: float) : - let v := (s, DTfloat) in - vartlookup env v = Some old1 -> - vartlookup env2 v = Some old2 -> - (let vl := map (fun ve => projT1 ve) σ in - forall j : {m' : nat | m' < m}, - fully_closed_over (vecdf j) vl) -> - lift (fun e => subvar (s, DTfloat) e old1) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (vecdf j) env0 (grad j)) - (Some env) (bounded_seq0 m)) = - lift (fun e => subvar (s, DTfloat) e old2) - (list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (vecdf j) env0 (grad j)) - (Some env2) (bounded_seq0 m)). - Proof. - intros. - subst v. - revert old1 old2 env env2 H H0. - induction (bounded_seq0 m). - - intros. - simpl. - unfold subvar; simpl. - rewrite H, H0. - f_equal; lra. - - intros. - simpl. - case_eq (df_eval_backprop_deriv σ (vecdf a) env (grad a)); intros. - case_eq (df_eval_backprop_deriv σ (vecdf a) env2 (grad a)); intros. - case_eq (vartlookup d (s, DTfloat)); intros. - case_eq (vartlookup d0 (s, DTfloat)); intros. - + specialize (IHl d1 d2 d d0 H4 H5). - unfold lift. - do 2 match_option. - * rewrite eqq, eqq0 in IHl. - unfold lift in IHl. - inversion IHl. - f_equal. - rewrite (split_subvar d d3 old1 d1); trivial. - rewrite (split_subvar d0 d4 old2 d2); trivial. - rewrite H7. - generalize (backprop_indep_env - σ (vecdf a) s - env env2 (grad a)); simpl; intros. - cut_to H6; trivial; try congruence. - unfold df_eval_backprop_delta in H6. - rewrite H, H0, H2, H3 in H6. - unfold lift in H6. - inversion H6. - rewrite H9; lra. - * generalize (list_env_iter_total_fun - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (vecdf j) env0 (grad j)) - d0 l); intros. - cut_to H6; [tauto|]. - intros. - apply backprop_deriv_fully_closed_not_none; auto. - * generalize (list_env_iter_total_fun - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (vecdf j) env0 (grad j)) - d l); intros. - cut_to H6; [tauto|]. - intros. - apply backprop_deriv_fully_closed_not_none; auto. - + apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H3 - ;trivial; congruence. - + apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H2 - ;trivial; congruence. - + assert (df_eval_backprop_deriv σ (vecdf a) env2 (grad a) <> None) by - (apply backprop_deriv_fully_closed_not_none; auto); tauto. - + assert (df_eval_backprop_deriv σ (vecdf a) env (grad a) <> None) by - (apply backprop_deriv_fully_closed_not_none; auto); tauto. - Qed. - - Lemma list_env_iter_mat_delta {n m} (σ:df_env) - (x:Matrix (DefinedFunction UnitAnn DTfloat) n m) (s: SubVar) grad_env - (i0 : {n' : nat | n' < n}) - (j0 : {m' : nat | m' < m}) (old : float) : - let v := (s, DTfloat) in - vartlookup grad_env v = Some old -> - (forall (i: {n' : nat | n' < n}) (j: {m' : nat | m' < m}) , - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over (x i j) vl) -> - lift (fun e => subvar v e old) - (df_eval_backprop_deriv σ (x i0 j0) grad_env 1%R) = - lift (fun e => subvar v e old) - (list_env_iter - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv - σ (x i j) env0 - (UnitMatrix n m i0 j0 i j)) (Some env) - (bounded_seq0 m)) (Some grad_env) (bounded_seq0 n)). - Proof. - simpl; intros. - generalize (list_env_iter_gen_delta - s grad_env old - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv - σ (x i j) env0 - (UnitMatrix n m i0 j0 i j)) (Some env) - (bounded_seq0 m)) i0); simpl. - intros. - specialize (H1 H). - rewrite <- H1. - - clear H1. - generalize (list_env_iter_gen_delta - s grad_env old - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i0 j) env0 (UnitMatrix n m i0 j0 i0 j)) - j0); simpl. - intros. - specialize (H1 H). - rewrite <- H1. - + unfold UnitMatrix; simpl. - destruct (equiv_dec (` i0) (` i0)); [|congruence]. - now destruct (equiv_dec (` j0) (` j0)); [|congruence]. - + clear H1. - intros. - split; [|split]. - * apply backprop_deriv_fully_closed_not_none; auto. - * intros. - match_option. - now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial. - * split. - -- intros. - do 4 match_option. - generalize (backprop_indep_env σ (x i0 i) s env env2 - (UnitMatrix n m i0 j0 i0 i)); simpl; intros HH. - cut_to HH; trivial; try congruence. - unfold df_eval_backprop_delta in HH. - rewrite eqq,eqq0,eqq1,eqq2 in HH. - unfold lift in HH. - now inversion HH. - -- intros. - unfold UnitMatrix; simpl. - destruct (equiv_dec (` i0) (` i0)); [|congruence]. - destruct (equiv_dec (` i) (` j0)). - elim H1; destruct i; destruct j0; simpl in *; red in e0; subst; apply index_pf_irrel. - unfold lift2. - generalize (scalarMult_backprop_grad0 σ (x i0 i) s env 0%R); simpl; intros. - unfold lift2. - replace (0 * 0)%R with 0%R in H3 by lra. - unfold df_eval_backprop_delta in H3. - match_option. - match_option. - rewrite eqq0, eqq in H3; unfold lift in H3. - cut_to H3; congruence. - rewrite eqq0 in H3. - apply H3; trivial. - tauto. - congruence. - assert (df_eval_backprop_deriv σ (x i0 i) env 0%R <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - - split. - generalize (list_env_iter_total_fun - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j)) - env (bounded_seq0 m)); intros. - apply H2. - intros. - apply backprop_deriv_fully_closed_not_none; trivial. - split. - + intros. - match_option. - apply (vartlookup_list_env_iter - s (fun (j : {m' : nat | m' < m}) - (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j)) - (bounded_seq0 m) env d); trivial. - intros. - apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H3; trivial. - + split. - * intros. - do 4 match_option. - generalize (list_env_iter_backprop_indep_env_vec - σ (x i) s env env2 - (UnitMatrix n m i0 j0 i) - d d0); simpl; intros. - specialize (H2 eqq eqq0). - specialize (H2 (H0 i)). - rewrite eqq1, eqq2 in H2. - unfold lift in H2. - now inversion H2. - * intros. - replace (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j)) - with - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv σ (x i j) env0 - (scalarMult DTfloat 0%R 0%R)). - case_eq (vartlookup env (s, DTfloat)); [intros | tauto]. - apply scalarMult_backprop_list_env_iter_grad0; trivial. - apply functional_extensionality; intros. - apply functional_extensionality; intros. - f_equal. - unfold scalarMult, UnitMatrix; simpl. - destruct (equiv_dec (` i) (` i0)). - red in e. - elim H2; destruct i; destruct i0; simpl in *; subst. - apply index_pf_irrel. - lra. - Qed. - - Lemma list_env_iter_matvec_delta {m n} (σ:df_env) - (df2:DefinedFunction UnitAnn (DTVector m)) (s: SubVar) grad_env - (i0 : {n' : nat | n' < n}) - (j0 : {m' : nat | m' < m}) (old : float) : - vartlookup grad_env (s, DTfloat) = Some old -> - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df2 vl -> - (lift (fun e => subvar (s, DTfloat) e old) - (df_eval_backprop_deriv - σ df2 grad_env - (UnitVector m j0)) = - (lift (fun e => subvar (s, DTfloat) e old) - (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ df2 env - ((UnitMatrix n m i0 j0) i)) - (Some grad_env) - (bounded_seq0 n)))). - Proof. - simpl; intros. - generalize (list_env_iter_gen_delta - s grad_env old - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ df2 env - ((UnitMatrix n m i0 j0) i)) - i0). - simpl; intros. - specialize (H1 H). - rewrite <- H1. - - unfold UnitMatrix, UnitVector; simpl. - now destruct (equiv_dec (` i0) (` i0)); [|congruence]. - - clear H1. - intros. - split; [|split]. - + apply backprop_deriv_fully_closed_not_none; auto. - + intros. - match_option. - now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial. - + split. - * intros. - do 4 match_option. - generalize (backprop_indep_env σ df2 s env env2 - (UnitMatrix n m i0 j0 i)); simpl; intros HH. - cut_to HH; trivial; try congruence. - unfold df_eval_backprop_delta in HH. - rewrite eqq,eqq0,eqq1,eqq2 in HH. - unfold lift in HH. - now inversion HH. - * intros. - unfold UnitMatrix; simpl. - destruct (equiv_dec (` i) (` i0)). - elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel. - unfold lift2. - generalize (scalarMult_backprop_grad0 σ df2 s env (ConstVector m 0%R)); simpl; intros. - unfold df_eval_backprop_delta in H3. - specialize (H3 H2). - match_option. - match_option; [|tauto]. - -- rewrite eqq0 in H3. - replace (fun i : {n' : nat | n' < m} => (0 * ConstVector m 0 i)%R) with - (fun i : {n' : nat | n' < m} => 0%R) in H3. - ++ rewrite eqq in H3; unfold lift in H3. - apply H3; congruence. - ++ apply functional_extensionality; intros. - unfold ConstVector; simpl. - lra. - -- assert (df_eval_backprop_deriv σ df2 env (fun _ => 0%R) <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - tauto. - Qed. - - Lemma vmap_eta {A B} {n} (f:A->B) (d:Vector A n) : vmap f d = vmap (fun x => f x) d. - Proof. - now apply vmap_ext. - Qed. - - Lemma vsum_eta {n} (d:Vector float n) : vsum d = vsum (fun i => d i). - Proof. - now apply vsum_ext. - Qed. - - Lemma msum_unitvector m n x d1 d0 : - msum - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (UnitVector m x i * d1 i j * d0 j)%R) = - vsum - (fun j : {n' : nat | n' < n} => - (d1 x j * d0 j)%R). - Proof. - unfold msum. - - transitivity ( - vsum - (fun i : {n' : nat | n' < m} => - vsum - (fun (j : {m' : nat | m' < n}) => (UnitVector m x i * d1 i j * d0 j)%R))). - { apply vsum_ext; intros ?. - now rewrite vmap_nth. - } - transitivity ( - vsum - (fun i : {n' : nat | n' < m} => - ((UnitVector m x i * vsum - (fun (j : {m' : nat | m' < n}%nat) => d1 i j * d0 j))%R))). - { - apply vsum_ext; intros ?. - rewrite vsum_mult. - apply vsum_ext; intros ?. - lra. - } - transitivity ( - vsum - (fun i : {n' : nat | n' < m} => - (vsum (fun j : {m' : nat | (m' < n)%nat} => d1 i j * d0 j) * UnitVector m x i)%R)). - { - apply vsum_ext; intros ?; lra. - } - now rewrite vsum_unitvector. - Qed. - - Lemma vsum_as_sum {n} (v:Vector float n) : vsum v = fold_right Rplus R0 (vector_to_list v). - Proof. - rewrite vsum_alt_eq. - unfold vector_to_list, vector_fold_right. - induction n. - - rewrite vector_fold_right_dep_0; trivial. - - repeat rewrite vector_fold_right_dep_Sn. - now rewrite IHn. - Qed. - - Lemma msum_as_sum {m n} (mat:Matrix float m n) : msum mat = fold_right Rplus R0 (matrix_to_list mat). - Proof. - unfold msum, matrix_to_list, matrix_to_list_list. - transitivity - (vsum (fun i => fold_right Rplus R0 (vector_to_list (mat i)))). - { apply vsum_ext; intros [??]. - now rewrite vmap_nth, vsum_as_sum. - } - rewrite vsum_as_sum. - rewrite fold_right_plus_concat. - rewrite map_vector_to_list_vmap. - f_equal. - apply vector_to_list_ext. - intros [??]. - now rewrite vmap_nth. - Qed. - - Lemma transpose_perm_bounded {m n : nat} (mat : Matrix float m n) bound_m pf_m bound_n pf_n: - Permutation - (concat - (vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] - (fun i : {n' : nat | n' < m} => - vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] (mat i) bound_n pf_n) bound_m pf_m)) - (concat - (vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] - (fun i : {n' : nat | n' < n} => - vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] (transpose mat i) bound_m pf_m) bound_n pf_n)). - Proof. - Hint Constructors Permutation. - revert bound_n pf_n. - induction bound_m; intros; simpl. - - induction bound_n; simpl; trivial. - - rewrite IHbound_m. - clear IHbound_m. - induction bound_n; simpl; trivial. - rewrite <- IHbound_n. - clear IHbound_n. - apply Permutation_cons; trivial. - unfold transpose; simpl. - repeat rewrite <- app_ass. - apply Permutation_app; trivial. - apply Permutation_app_comm. - Qed. - - Lemma transpose_perm {m n} (mat : Matrix float m n) : - Permutation (matrix_to_list mat) (matrix_to_list (transpose mat)). - Proof. - apply transpose_perm_bounded. - Qed. - - Lemma msum_transpose {m n} (mat : Matrix float m n) : - msum mat = msum (transpose mat). - Proof. - repeat rewrite msum_as_sum. - apply fold_right_perm; intros; try lra. - apply transpose_perm. - Qed. - - Ltac match_nested_case := - match goal with - | [|- context[match match ?x with _ => _ end with _ => _ end]] => - let eqq := fresh "eqq" in - case_eq x - ; [intros ? eqq | intros eqq] - end. - - Ltac match_nested_case_in H := - match H with - | context[match match ?x with _ => _ end with _ => _ end] => - let eqq := fresh "eqq" in - case_eq x - ; [intros ? eqq | intros eqq] - ; rewrite eqq in H - end. - - Theorem df_eval_deriv_genvar_same (σ:df_env) (df:DefinedFunction UnitAnn DTfloat) (v:SubVar) : - let vl := map (fun ve => projT1 ve) σ in - is_scalar_function df -> - fully_closed_over df vl -> - let forward := df_eval_deriv_gen_top σ df (v, DTfloat) in - lift transpose_lifted_type forward = df_eval σ (df_deriv df (v,DTfloat)). - Proof. - simpl. - intros is_scalar. - generalize is_scalar. - pattern df. - revert df is_scalar. - DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl; trivial; intros - ; try - ( cut_to H; trivial - ;rewrite <- H - ;assert (df_eval σ e <> None) by (apply eval_fully_closed_not_none; trivial) - ;unfold lift, df_eval_deriv_gen_top; simpl - ;match_nested_case; [|tauto] - ;now match_nested_case). - - - Case "Var"%string. - match_nested_case. - + red in e. - inversion e; subst. - now refl_simpler; simpl. - + now intros. - - Case "Plus"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_option_in H. - + match_option_in H0. - * rewrite <- H. - now rewrite <- H0. - * assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - - Case "Minus"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_option_in H. - + match_option_in H0. - * rewrite <- H. - now rewrite <- H0. - * assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - - Case "Times"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_nested_case; trivial. - match_option_in H. - + match_nested_case. - * match_option_in H0. - -- rewrite <- H. - now rewrite <- H0. - -- assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - * assert (df_eval σ r <> None) - ; [apply eval_fully_closed_not_none; trivial | tauto]. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - - Case "Divide"%string. - destruct H1. - destruct is_scalar. - cut_to H; trivial. - cut_to H0; trivial. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_nested_case; trivial. - match_option_in H. - + match_nested_case. - * match_option_in H0. - -- rewrite <- H. - now rewrite <- H0. - -- assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - * assert (df_eval σ r <> None) - ; [apply eval_fully_closed_not_none; trivial | tauto]. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - + assert (df_eval σ l <> None) - ; [apply eval_fully_closed_not_none; trivial | tauto]. - - Case "Sign"%string. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_nested_case; [trivial|]. - assert (df_eval_deriv_genvar σ e [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - - Case "PSign"%string. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - match_nested_case; [trivial|]. - assert (df_eval_deriv_genvar σ e [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - - Case "Max"%string. - destruct is_scalar; destruct H1. - cut_to H; trivial. - cut_to H0; trivial. - unfold lift, df_eval_deriv_gen_top in *; simpl in *. - assert (df_eval σ l <> None) by (apply eval_fully_closed_not_none; trivial). - assert (df_eval σ r <> None) by (apply eval_fully_closed_not_none; trivial). - match_nested_case; [|tauto]. - match_nested_case; [|tauto]. - match_option_in H. - match_option_in H0. - + rewrite <- H. - rewrite <- H0. - unfold pos_sign; simpl. - case_eq (Rle_dec d d0); intros; f_equal. - destruct (Rge_dec (d0 - d) 0); lra. - destruct (Rge_dec (d0 - d) 0); lra. - + assert (df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto]. - apply eval_deriv_genvar_fully_closed_not_none; trivial. - Qed. - - Lemma yay {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (hs:has_scalar_functions df) (s: SubVar) grad_env : - let v := (s, DTfloat) in - vartlookup grad_env v <> None -> - let vl := map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - let forward := df_eval_deriv_gen_top σ df v in - let backward := df_eval_backward_gen_top σ df v grad_env in - lift transpose_lifted_type forward = backward. - Proof. - simpl. - intros vin closed. - revert grad_env vin closed. - unfold df_eval_deriv_gen_top, df_eval_backward_gen_top. - pattern T, df. - revert T df hs. - DefinedFunction_cases (apply DefinedFunction_ind_unit_has_scalar_functions) Case - ; simpl; intros. - - Case "Number"%string. - unfold subvar; simpl. - match_destr; [ | tauto]. - f_equal; lra. - - Case "Constant"%string. - unfold subvar; simpl. - match_destr; simpl. - + match_destr; [ | tauto]. - f_equal; lra. - + match_destr; [ | tauto]. - erewrite vectoro_to_ovector_forall_some_b_strong - ; simpl; trivial; intros. - unfold ConstVector. - f_equal; lra. - + match_destr; [ | tauto]. - unfold matrixo_to_omatrix. - repeat (erewrite vectoro_to_ovector_forall_some_b_strong - ; simpl; trivial; intros). - unfold ConstMatrix. - f_equal; lra. - - Case "DVector"%string. - unfold lift, two_vector_env_iter_alt. - case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto]. - match_option. - + specialize (apply vectoro_to_ovector_forall_some_f eqq); intros. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - unfold snd in H. - specialize (H i grad_env); simpl in H. - rewrite vforall_forall in closed. - assert (closedb := closed). - specialize (closed i). - specialize (H vin closed); simpl in H. - rewrite H0 in H. - specialize (H1 i); simpl in H1. - rewrite H1 in H; simpl in H. - unfold lift in H. - match_option_in H. - specialize (apply vectoro_to_ovector_forall_some_f eqq); intros. - specialize (H2 i); simpl in H2. - destruct i; simpl. - unfold UnitVector; simpl. - generalize (list_env_iter_total_fun - (fun (i : {n' : nat | n' < n}) (env : df_env) => - df_eval_backprop_deriv σ (x i) env - (if equiv_dec (` i) x0 then 1%R else 0%R)) - grad_env (bounded_seq0 n)); intros. - cut_to H3. - match_option; [|tauto]. - * rewrite H. - generalize (list_env_iter_vec_delta σ x s grad_env - (exist (fun n' : nat => n' < n) x0 l) d). - intros. - simpl in H4. - specialize (H4 H0 closedb). - rewrite eqq0 in H4. - unfold UnitVector in H4; simpl in H4. - rewrite eqq1 in H4. - unfold lift in H4. - symmetry; trivial. - * intros. - apply backprop_deriv_fully_closed_not_none. - apply closedb. - + specialize (vectoro_to_ovector_exists_None eqq); intros. - destruct H1. - generalize (eval_deriv_genvar_fully_closed_not_none σ (x x0) [mk_env_entry (s, DTfloat) 1%R]); intros. - rewrite vforall_forall in closed. - now specialize (H1 (closed x0)). - - Case "DMatrix"%string. - unfold lift, matrixo_to_omatrix. - case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto]. - match_option. - + specialize (apply vectoro_to_ovector_forall_some_f eqq); intros. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H1 i); simpl in H1. - specialize (apply vectoro_to_ovector_forall_some_f H1); intros. - specialize (H2 i0); simpl in H2. - unfold two_matrix_env_iter_alt. - specialize (H i i0 grad_env); simpl in H. - rewrite vforall_forall in closed. - assert (closedb := closed). - specialize (closed i). - rewrite vforall_forall in closed. - specialize (closed i0). - specialize (H vin closed); simpl in H. - rewrite H2 in H; simpl in H. - rewrite H0 in H; simpl in H. - unfold lift in H. - match_option_in H. - rewrite H. - destruct i. - destruct i0. - unfold lift; simpl. - generalize (list_env_iter_total_fun - (fun (i : {n' : nat | n' < n}) (env : df_env) => - list_env_iter - (fun (j : {m' : nat | m' < m}) (env0 : df_env) => - df_eval_backprop_deriv - σ (x i j) env0 - (UnitMatrix n m (exist (fun n' : nat => n' < n) x0 l) - (exist (fun n' : nat => n' < m) x1 l0) i j)) - (Some env) - (bounded_seq0 m)) - grad_env (bounded_seq0 n)); intros. - cut_to H3. - match_option; [|tauto]. - * generalize (list_env_iter_mat_delta σ x s grad_env - (exist (fun n' : nat => n' < n) x0 l) - (exist (fun n' : nat => n' < m) x1 l0) d). - intros. - simpl in H4. - specialize (H4 H0). - cut_to H4. - rewrite eqq0 in H4. - rewrite eqq1 in H4. - unfold lift in H4. - symmetry; trivial. - intros. - specialize (closedb i). - rewrite vforall_forall in closedb. - apply closedb. - * intros. - apply list_env_iter_total_fun; intros. - apply backprop_deriv_fully_closed_not_none. - specialize (closedb a). - rewrite vforall_forall in closedb. - apply closedb. - + specialize (vectoro_to_ovector_exists_None eqq); intros. - destruct H1. - specialize (vectoro_to_ovector_exists_None e); intros. - destruct H1. - symmetry. - generalize (eval_deriv_genvar_fully_closed_not_none σ (x x0 x1) [mk_env_entry (s, DTfloat) 1%R]); intros. - rewrite vforall_forall in closed. - specialize (closed x0). - rewrite vforall_forall in closed. - now specialize (H1 (closed x1)). - - Case "Var"%string. - unfold equiv_dec, vart_eqdec; simpl. - destruct (vart_dec v (s, DTfloat)). - + destruct v. - inversion e. - subst; simpl. - refl_simpler; simpl. - case_eq (vartlookup grad_env (s, DTfloat)); [intros |tauto]. - simpl; f_equal. - symmetry. - now apply subvar_addvar_scalar_eq. - + case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto]. - destruct v. - unfold snd. - destruct d0. - * simpl. - unfold lift. - destruct (vartlookup grad_env (s0, DTfloat)). - -- f_equal; simpl; symmetry. - now apply subvar_addvar_scalar_neq. - -- f_equal; symmetry. - unfold subvar; simpl. - rewrite H; lra. - * simpl; symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros; unfold lift. - destruct (vartlookup grad_env (s0, DTVector n0)). - -- f_equal; simpl. - unfold ConstVector. - now apply subvar_addvar_scalar_neq. - -- f_equal; unfold ConstVector. - unfold subvar; simpl. - rewrite H; lra. - * simpl; symmetry. - unfold matrixo_to_omatrix. - apply vectoro_to_ovector_forall_some_b_strong; intros. - apply vectoro_to_ovector_forall_some_b_strong; intros. - unfold lift. - destruct (vartlookup grad_env (s0, DTMatrix m n0)). - -- f_equal; unfold ConstMatrix. - now apply subvar_addvar_scalar_neq. - -- f_equal; unfold ConstMatrix. - unfold subvar; simpl. - rewrite H; lra. - - Case "Plus"%string. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H. - case_eq (df_eval_backprop_deriv σ l grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - destruct closed. - specialize (H H1). - invcs H. - { specialize (H0 d1). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))| apply H2]. - match_option - ; rewrite eqq2 in H0 - ; simpl in *. - - match_option_in H0. - unfold lift in H0. - match_option_in H0. - invcs H0. - simpl. - f_equal. - unfold subvar; simpl. - rewrite eqq3. - match_option; lra. - - match_option_in H0; simpl. - + unfold lift in *. - match_option_in H0. - + elim (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat) vin eqq3). - } - + destruct closed. - specialize (H H1). - congruence. - + destruct closed. - specialize (H H1). - congruence. - + destruct closed. - specialize (H H1). - case_eq (df_eval_backprop_deriv σ l grad_env 1%R) - ; [intros ? eqq2 | intros eqq2]. - rewrite eqq2 in H. - simpl in *. - congruence. - rewrite eqq2 in H. - now apply H. - - Case "Minus"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H. - case_eq (df_eval_backprop_deriv σ l grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - invcs H. - { specialize (H0 d1). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - | apply H2]. - match_option - ; rewrite eqq2 in H0 - ; simpl in *. - - match_option_in H0. - unfold lift in H0. - match_option_in H0. - invcs H0. - simpl. - f_equal. - unfold subvar; simpl. - rewrite eqq3. - match_option. - + case_eq (df_eval_backprop_deriv σ r d1 (- (1))%R); intros. - * unfold lift. - f_equal. - match_option. - -- generalize (scalarMult_backprop_grad_scalar σ r s d1 d1 1%R (-1)%R) - ; intros; simpl in H0. - cut_to H0. - ++ unfold df_eval_backprop_delta in H0. - rewrite eqq3 in H0; unfold lift in H0; simpl in H0. - replace (-1 * 1)%R with (- (1))%R in H0 by lra. - rewrite H, eqq4 in H0; inversion H0. - unfold subvar in H0; simpl in H0. - rewrite eqq6, eqq5 in H0. - inversion H0. - lra. - ++ congruence. - ++ congruence. - ++ simpl; replace (-1 * 1)%R with (- (1))%R by lra; congruence. - ++ congruence. - -- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (s,DTfloat)) - ; intros. - cut_to H0; congruence. - * generalize (backprop_deriv_fully_closed_not_none σ r d1 (- (1))%R); intros. - destruct H0; trivial. - + generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s,DTfloat)) - ; intros. - cut_to H; congruence. - - match_option_in H0. - + unfold lift. - unfold lift in H0. - match_option_in H0. - match_option. - specialize (df_eval_backprop_deriv_preserves_lookup_not_none eqq5). - intros. - specialize (H (s, DTfloat)). - generalize (backprop_deriv_fully_closed_not_none σ r d1 (1%R)); intros. - now destruct H3. - + generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s,DTfloat)) - ; intros. - cut_to H; congruence. - } - + unfold lift in H; simpl in H. - match_option_in H. - - Case "Times"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - case_eq (df_eval_backprop_deriv σ l grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ l s grad_env grad_env 1%R d1) - ; intros; simpl in H7. - unfold df_eval_backprop_delta in H7. - rewrite eqq1 in H7; simpl in H7. - specialize (H7 vin vin). - rewrite eqq0 in H7. - generalize (backprop_deriv_fully_closed_not_none σ l grad_env (d1 * 1)%R); intros. - specialize (H8 H1); specialize (H7 H8). - cut_to H7; try discriminate. - invcs H. - { specialize (H0 d3). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - | apply H2]. - unfold lift in H7; simpl in H7. - match_option_in H7. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - specialize (H vin). - generalize (backprop_deriv_fully_closed_not_none σ r d3 1%R); intros. - specialize (H9 H2). - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H10 in H0. - unfold lift; simpl. - unfold lift in H0. - match_option_in H0. - - generalize (scalarMult_backprop_grad_scalar σ r s d2 d3 1%R d) - ; intros. - unfold df_eval_backprop_delta in H11; simpl in H11. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros. - specialize (H12 vin). - case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence]. - specialize (H11 H12 H). - rewrite H10, H13 in H11. - generalize (backprop_deriv_fully_closed_not_none σ r d2 (d * 1)%R); intros. - specialize (H14 H2); specialize (H11 H14 H9). - unfold lift in H11. - match_option_in H11; [|congruence]; f_equal. - rewrite (split_subvar d2 d7 d0 d6); trivial. - match_option_in H0; invcs H0. - rewrite eqq5 in H11; invcs H11; invcs H7. - lra. - - now match_option_in H0. - } - unfold lift in H. - match_case_in H; intros. - rewrite H7 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ l grad_env 1%R); intros. - now specialize (H8 H1). - - Case "Divide"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - case_eq (df_eval_backprop_deriv σ l grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ l s grad_env grad_env 1%R (1 / d1)%R) - ; intros; simpl in H7. - unfold df_eval_backprop_delta in H7. - rewrite eqq1 in H7; simpl in H7. - specialize (H7 vin vin). - rewrite eqq0 in H7. - generalize (backprop_deriv_fully_closed_not_none σ l grad_env (1 / d1 * 1)%R); intros. - specialize (H8 H1) ; specialize (H7 H8). - cut_to H7; try discriminate. - replace (1 / d1 * 1)%R with (1/d1)%R in H7 by lra. - invcs H. - { specialize (H0 d3). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - | apply H2]. - unfold lift in H7; simpl in H7. - match_option_in H7. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - specialize (H vin). - generalize (backprop_deriv_fully_closed_not_none σ r d3 1%R); intros. - specialize (H9 H2). - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H10 in H0. - unfold lift; simpl. - unfold lift in H0. - match_option_in H0. - - generalize (scalarMult_backprop_grad_scalar σ r s d2 d3 1%R (- d / (d1 * d1))%R) - ; intros; simpl in H11. - unfold df_eval_backprop_delta in H11; simpl in H11. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros. - specialize (H12 vin). - case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence]. - specialize (H11 H12 H). - rewrite H10, H13 in H11. - generalize (backprop_deriv_fully_closed_not_none σ r d2 (- d / (d1 * d1) * 1)%R); intros. - specialize (H14 H2); specialize (H11 H14 H9). - unfold lift in H11. - match_option_in H11; [|congruence]; f_equal. - rewrite (split_subvar d2 d7 d0 d6); trivial. - match_option_in H0; invcs H0. - rewrite eqq5 in H11; invcs H11; invcs H7. - lra. - - now match_option_in H0. - } - unfold lift in H. - match_case_in H; intros. - rewrite H7 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ l grad_env 1%R); intros. - now specialize (H8 H1). - - Case "Square"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (2 * d)%R) - ; intros; simpl in H2. - unfold df_eval_backprop_delta in H2. - rewrite eqq1 in H2; simpl in H2. - specialize (H2 vin vin). - rewrite eqq0 in H2. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (2 * d * 1)%R); intros. - specialize (H3 closed); specialize (H2 H3). - cut_to H2; try discriminate. - invcs H. - unfold lift in H2; match_option_in H2. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; now rewrite H2. - + unfold lift in H. - match_case_in H; intros. - rewrite H2 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H3 closed). - - Case "Exp"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (exp d)%R) - ; intros; simpl in H2. - unfold df_eval_backprop_delta in H2. - rewrite eqq1 in H2; simpl in H2. - specialize (H2 vin vin). - rewrite eqq0 in H2. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (exp d * 1)%R); intros. - specialize (H3 closed); specialize (H2 H3). - cut_to H2; try discriminate. - invcs H. - replace (1 * exp d)%R with (exp d * 1)%R by lra. - unfold lift in H2; match_option_in H2. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; rewrite H2. - f_equal; lra. - + unfold lift in H. - match_case_in H; intros. - rewrite H2 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H3 closed). - - Case "Log"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (1 / d)%R) - ; intros; simpl in H2. - unfold df_eval_backprop_delta in H2. - rewrite eqq1 in H2; simpl in H2. - specialize (H2 vin vin). - rewrite eqq0 in H2. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (1 / d * 1)%R); intros. - specialize (H3 closed); specialize (H2 H3). - cut_to H2; try discriminate. - invcs H. - replace (1 / d)%R with (1 / d * 1)%R by lra. - unfold lift in H2; match_option_in H2. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; rewrite H2. - f_equal; lra. - + unfold lift in H. - match_case_in H; intros. - rewrite H2 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H3 closed). - - Case "Abs"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (sign d)%R) - ; intros; simpl in H2. - unfold df_eval_backprop_delta in H2. - rewrite eqq1 in H2; simpl in H2. - specialize (H2 vin vin). - rewrite eqq0 in H2. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (sign d * 1)%R); intros. - specialize (H3 closed); specialize (H2 H3). - cut_to H2; try discriminate. - invcs H. - replace (1 * (@sign floatish_R d))%R with ((@sign floatish_R d) * 1)%R by lra. - unfold lift in H2; match_option_in H2. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; rewrite H2. - f_equal; lra. - + unfold lift in H. - match_case_in H; intros. - rewrite H2 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H3 closed). - - Case "Sign"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (0)%R) - ; intros H1; simpl in H1. - unfold df_eval_backprop_delta in H1. - rewrite eqq1 in H1; simpl in H1. - specialize (H1 vin vin). - rewrite eqq0 in H1. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (0 * 1)%R); intros. - specialize (H2 closed); specialize (H1 H2). - cut_to H1; try discriminate. - invcs H. - replace (0)%R with (0 * 1)%R by lra. - unfold lift in H1; match_option_in H1. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; rewrite H1. - f_equal; lra. - + unfold lift in H. - match_case_in H; intros. - rewrite H1 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H2 closed). - - Case "PSign"%string. - specialize (H grad_env vin closed). - simpl in *. - generalize (eval_fully_closed_not_none σ e); intros. - specialize (H0 closed). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - + case_eq (df_eval_backprop_deriv σ e grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (0)%R) - ; intros H1; simpl in H1. - unfold df_eval_backprop_delta in H1. - rewrite eqq1 in H1; simpl in H1. - specialize (H1 vin vin). - rewrite eqq0 in H1. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env (0 * 1)%R); intros. - specialize (H2 closed); specialize (H1 H2). - cut_to H1; try discriminate. - invcs H. - replace (0)%R with (0 * 1)%R by lra. - unfold lift in H1; match_option_in H1. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - ;intros. - unfold lift; rewrite H1. - f_equal; lra. - + unfold lift in H. - match_case_in H; intros. - rewrite H1 in H; congruence. - generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros. - now specialize (H2 closed). - - Case "Max"%string. - destruct closed. - specialize (H grad_env vin H1). - specialize (H0 grad_env vin H2). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros|tauto]. - rewrite eqq0 in H. - rewrite eqq0 in H0. - destruct (Rle_dec d d1). - + apply H0. - + apply H. - - Case "VectorDot"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros | tauto]. - replace (fun rv : R => (rv * 1)%R) with id. - rewrite vmap_id; rewrite vmap_id. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto]. - symmetry in H. - generalize (backprop_deriv_fully_closed_not_none σ l grad_env d0); intros. - specialize (H7 H1). - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq1 in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H8. - match_option. - + match_option; [|tauto]. - generalize (backprop_deriv_fully_closed_not_none σ r d4 d); intros. - specialize (H9 H2). - unfold lift. - match_option; [|tauto]. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros. - specialize (H10 vin). - specialize (H0 d4 H10 H2). - rewrite eqq0 in H0. - match_option_in H0. - symmetry in H0. - unfold lift in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros. - simpl in H11. - f_equal. - unfold lift in H8. - rewrite (split_subvar d4 d5 d1 d6); trivial. - rewrite <- vsum_plus. - generalize (df_eval_backprop_delta_by_unit_partvec σ l s grad_env d0 - (fun i => (d2 i * d0 i)%R)); intros. - specialize (H12 H1 vin). - generalize (df_eval_backprop_delta_by_unit_partvec σ r s d4 d - (fun i => (d i * d3 i)%R)); intros. - specialize (H13 H2 H10). - cut_to H12. - cut_to H13. - * unfold df_eval_backprop_delta in H12. - rewrite eqq1 in H12. - unfold df_eval_backprop_delta in H13. - rewrite eqq4 in H13. - unfold lift in H12. - unfold lift in H13. - rewrite eqq2 in H12. - rewrite eqq3 in H13. - invcs H12; invcs H13. - rewrite H14, H15; lra. - * intros. - replace (@scaleUnitVector (@float floatish_R) n i (d i) (IZR Z0)) with - (scalarMult (DTVector n) (d i) (UnitVector n i)). - rewrite scalarMult_backprop_grad_scalar with (grad_env1 := d4) (grad_env2:=d4);trivial. - unfold df_eval_backprop_delta, lift. - rewrite eqq4. - specialize (H11 i). - destruct i; simpl in H11. - simpl_closed_backprop. - f_equal. - rewrite eqq5 in H11. - invcs H11. - rewrite H15. - now unfold scalarMult; simpl. - apply backprop_deriv_fully_closed_not_none; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - unfold scalarMult, UnitVector, scaleUnitVector; simpl. - apply functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)); lra. - * intros. - replace (@scaleUnitVector (@float floatish_R) n i (d0 i) (IZR Z0)) with - (scalarMult (DTVector n) (d0 i) (UnitVector n i)). - rewrite scalarMult_backprop_grad_scalar with (grad_env1 := grad_env) (grad_env2:=grad_env);trivial. - unfold df_eval_backprop_delta, lift. - rewrite eqq1. - specialize (H8 i). - destruct i; simpl in H8. - simpl_closed_backprop. - f_equal. - rewrite eqq5 in H8. - invcs H8. - rewrite H15. - now unfold scalarMult; simpl; lra. - apply backprop_deriv_fully_closed_not_none; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - unfold scalarMult, UnitVector, scaleUnitVector; simpl. - apply functional_extensionality; intros. - destruct (equiv_dec (` x) (` i)); lra. - + generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros. - now specialize (H9 H2). - + generalize (eval_deriv_genvar_fully_closed_not_none σ l [mk_env_entry (s, DTfloat) 1%R]); intros. - now specialize (H8 H1). - + apply FunctionalExtensionality.functional_extensionality. - intros; unfold id; lra. - - Case "VectorSum"%string. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H closed). - symmetry in H. - generalize (vectoro_to_ovector_forall_some_f H); intros HH; clear H. - replace (@lift (@df_env floatish_R) R - (fun e : df_env => subvar (s, DTfloat) e d0) - (df_eval_backprop_deriv σ v grad_env - (ConstVector n 1%R))) - with - (df_eval_backprop_delta σ v (s, DTfloat) grad_env - (ConstVector n 1%R)). - rewrite df_eval_backprop_delta_by_unit_parts with (d1 := d); trivial. - * intros. - specialize (HH i). - unfold df_eval_backprop_delta. - rewrite eqq0. - replace (UnitVector n i) with - (coerce - (df_eval_backward_gen_top_obligation_2 UnitAnn (DTVector n) v n eq_refl i) - (UnitVector n i)); trivial. - destruct i. - now simpl. - * unfold df_eval_backprop_delta. - now rewrite eqq0. - + specialize (H closed). - symmetry in H. - specialize (vectoro_to_ovector_exists_None H); intros. - destruct H0. - unfold lift in e. - match_option_in e. - generalize (backprop_deriv_fully_closed_not_none - σ v grad_env - (coerce - (df_eval_backward_gen_top_obligation_2 - UnitAnn (DTVector n) v n eq_refl x) - (UnitVector n x))); intros. - specialize (H0 closed); tauto. - - Case "MatrixSum"%string. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H closed). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros HH; simpl in HH. - replace (@lift (@df_env floatish_R) R - (fun e : df_env => subvar (s, DTfloat) e d0) - (df_eval_backprop_deriv σ v grad_env - (ConstMatrix m n 1%R))) - with - (df_eval_backprop_delta σ v (s, DTfloat) grad_env - (ConstMatrix m n 1%R)). - rewrite df_eval_backprop_delta_by_unit_parts_mat with (d1 := d); trivial. - * intros. - specialize (HH i). - specialize (apply vectoro_to_ovector_forall_some_f HH); intros. - specialize (H0 j); simpl in H0. - unfold df_eval_backprop_delta. - rewrite eqq0. - replace (UnitMatrix m n i j) with - (coerce - (df_eval_backward_gen_top_obligation_3 - UnitAnn (DTMatrix m n) v m n eq_refl i j) - (UnitMatrix m n i j)); trivial. - destruct i; destruct j. - now simpl. - * unfold df_eval_backprop_delta. - now rewrite eqq0. - + specialize (H closed). - symmetry in H. - specialize (vectoro_to_ovector_exists_None H); intros. - destruct H0. - unfold lift in e. - specialize (vectoro_to_ovector_exists_None e); intros. - destruct H0. - match_option_in e0. - generalize (backprop_deriv_fully_closed_not_none - σ v grad_env - (coerce - (df_eval_backward_gen_top_obligation_3 - UnitAnn (DTMatrix m n) v m n eq_refl x x0) - (UnitMatrix m n x x0))); intros. - specialize (H0 closed); tauto. - - Case "VectorElem"%string. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H closed). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H. - unfold lift. - replace (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i) then 1%R else 0%R) - with - (UnitVector n i). - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env (UnitVector n i)); intros. - specialize (H1 closed). - specialize (H0 i). - match_option; symmetry; [|tauto]. - destruct i; simpl in *. - unfold lift in H0; f_equal. - rewrite eqq1 in H0. - now invcs H0. - unfold UnitVector. - apply FunctionalExtensionality.functional_extensionality; intros. - trivial. - + specialize (H closed). - symmetry in H. - specialize (vectoro_to_ovector_exists_None H); intros. - destruct H0. - unfold lift in e. - match_option_in e. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_2 - UnitAnn (DTVector n) l n eq_refl x) - (UnitVector n x))); intros. - specialize (H0 closed); tauto. - - Case "MatrixElem"%string. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H closed). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H. - unfold lift. - replace - (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) => - if equiv_dec (` k1) (` i) then - if equiv_dec (` k2) (` j) then 1%R else 0%R else 0%R) - with (UnitMatrix m n i j). - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env (UnitMatrix m n i j)); intros. - specialize (H1 closed). - specialize (H0 i). - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H2. - specialize (H2 j). - match_option; symmetry; [|tauto]. - destruct i; destruct j; simpl in *. - unfold lift in H2; f_equal. - rewrite eqq1 in H2. - now invcs H2. - unfold UnitMatrix. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - trivial. - + specialize (H closed). - symmetry in H. - specialize (vectoro_to_ovector_exists_None H); intros. - destruct H0. - unfold lift in e. - specialize (vectoro_to_ovector_exists_None e); intros. - destruct H0. - match_option_in e0. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_3 - UnitAnn (DTMatrix m n) l m n eq_refl x x0) - (UnitMatrix m n x x0))); intros. - specialize (H0 closed); tauto. - - Case "MatrixVectorMult"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros | tauto]. - assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option; [|tauto]. - assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - simpl_closed_backprop. - simpl_closed_backprop. - f_equal. - rewrite eqq1, eqq in H. - unfold lift in H; symmetry in H. - specialize (vectoro_to_ovector_forall_some_f H); intros. - specialize (H0 env). - simpler2. - cut_to H0; try congruence. - rewrite eqq0 in H0. - rewrite (split_subvar env env0 d3 val); trivial. - specialize (H9 i); simpl in H9. - specialize (vectoro_to_ovector_forall_some_f H9); intros. - replace (@vsum floatish_R n (fun j : {n' : nat | n' < n} => (d i j * d2 j + d1 i j * d0 j)%R)) - with ((vsum (fun j => (d i j * d2 j)%R)) + (vsum (fun j => (d1 i j * d0 j)%R)))%R - ; [|rewrite vsum_plus; f_equal]. - unfold lift in H0; symmetry in H0. - specialize (vectoro_to_ovector_forall_some_f H0); intros. - simpl in H11; simpl in H10. - destruct i; simpl in eqq2; simpl in eqq3. - unfold matrix_vector_mult in eqq3; simpl in eqq3. - unfold UnitVector in eqq2; simpl in eqq2. - replace (fun i : {n' : nat | n' < n} => - (@vsum floatish_R m - (fun j : {n' : nat | n' < m} => - (d j i * (@UnitVector floatish_R m (exist (fun n' : nat => (n' < m)%nat) x l0) - j)%R)%R))) - with (d (exist (fun n' : nat => (n' < m)%nat) x l0)) in eqq3. - + generalize (df_eval_backprop_delta_by_unit_partvec - σ r s env (d (exist (fun n' : nat => n' < m) x l0 )) - (fun i => ((d (exist (fun n' : nat => (n' < m)%nat) x l0 ) i) * (d2 i))%R)) - ; intros. - specialize (H12 H2). - cut_to H12; try congruence. - * unfold df_eval_backprop_delta in H12. - rewrite eqq3, eqq4 in H12. - unfold lift in H12. - invcs H12. - apply Rplus_eq_compat_l. - generalize (df_eval_backprop_delta_by_unit_partmat - σ l s grad_env - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - ((if equiv_dec (` i) x then 1 else 0) * d0 j)%R) - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - ((if equiv_dec (` i) x then 1 else 0) * (d1 i j) * (d0 j))%R)) - ; intros. - specialize (H12 H1). - cut_to H12; try congruence. - -- unfold df_eval_backprop_delta in H12. - rewrite eqq1 in H12. - unfold lift in H12. - rewrite eqq2 in H12. - invcs H12. - rewrite H15. - replace - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - ((if equiv_dec (` i) x then 1 else 0) * d1 i j * d0 j)%R) with - (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => - (UnitVector m (exist _ x l0) i * d1 i j * d0 j)%R). - now rewrite msum_unitvector. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - now unfold UnitVector; simpl. - -- intros. - rewrite scalarMult_backprop_grad_scalar with (grad_env2:=grad_env) - ;trivial; try congruence. - unfold lift. - unfold df_eval_backprop_delta. - rewrite eqq1. - unfold lift. - specialize (vectoro_to_ovector_forall_some_f H); intros. - specialize (H13 i); simpl in H13. - specialize (vectoro_to_ovector_forall_some_f H13); intros. - specialize (H15 j); intros; simpl in H15. - destruct i; destruct j; simpl in H15. - match_option_in H15. - f_equal; simpl. - destruct (equiv_dec x0 x). - invcs H15. - rewrite H17; lra. - lra. - apply backprop_deriv_fully_closed_not_none; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - * intros. - replace (@scaleUnitVector (@float floatish_R) n i (d (@exist nat (fun n' : nat => lt n' m) x l0) i) (IZR Z0)) with - (scalarMult (DTVector n) (d (exist (fun n' : nat => n' < m) x l0) i) (UnitVector n i)). - rewrite scalarMult_backprop_grad_scalar with (grad_env1 := env) (grad_env2:=env);trivial; try congruence. - unfold scalarMult; simpl. - unfold lift; simpl. - specialize (H11 i). - destruct i; simpl in H11. - match_option_in H11. - unfold df_eval_backprop_delta. - rewrite eqq4,eqq5. - unfold lift; f_equal. - now invcs H11. - apply backprop_deriv_fully_closed_not_none; trivial. - apply backprop_deriv_fully_closed_not_none; trivial. - unfold scalarMult, scaleUnitVector. - apply FunctionalExtensionality.functional_extensionality; intros. - unfold UnitVector. - destruct (equiv_dec (` x0) (` i)); simpl; lra. - + apply FunctionalExtensionality.functional_extensionality; intros. - now rewrite vsum_unitvector. - - Case "MatrixVectorAdd"%string. - destruct closed. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3. - match_option; symmetry. - * apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H3 i). - specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H4 i0). - unfold lift in H4; unfold lift. - match_option_in H4. - unfold lift in H0. - rewrite eqq1 in H0. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat)) - ;intros. - specialize (H5 vin). - specialize (H0 d2 H5 H2). - match_option_in H0. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6. - specialize (H6 i). - destruct i; destruct i0; simpl; simpl in eqq2; simpl in H6. - rewrite eqq2. - match_option_in H6. - generalize (list_env_iter_total_fun - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ r env - (transpose - (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1)) i)) - d2 (bounded_seq0 n)); intros. - cut_to H7. - -- case_eq (list_env_iter - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ r env - (@transpose R m n - (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1)) i)) - (Some d2) - (bounded_seq0 n)); [intros |tauto]. - f_equal. - rewrite (split_subvar d2 d5 d0 d3); trivial. - invcs H4. - invcs H6. - generalize (list_env_iter_matvec_delta - σ r s d2 - (exist (fun n' : nat => n' < n) x0 l1) - (exist (fun n' : nat => n' < m) x l0) d3). - intros. - specialize (H4 eqq3 H2). - rewrite eqq4 in H4. - replace (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ r env - (UnitMatrix n m (exist (fun n' : nat => n' < n) x0 l1) - (exist (fun n' : nat => n' < m) x l0) i)) with - (fun (i : {m' : nat | m' < n}) (env : df_env) => - df_eval_backprop_deriv - σ r env - (@transpose R m n - (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1)) i)) in H4. - ++ rewrite H8 in H4. - unfold lift in H4. - invcs H4; lra. - ++ apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - f_equal. - unfold transpose, UnitMatrix. - apply FunctionalExtensionality.functional_extensionality; intros. - simpl. - match_case; intros. - match_case; intros. - -- intros. - apply backprop_deriv_fully_closed_not_none; trivial. - * assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - - Case "MatrixMult"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - generalize (eval_fully_closed_not_none σ r); intros. - specialize (H5 H2). - case_eq (df_eval σ r); [intros | tauto]. - assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option; [|tauto]. - assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option; [|tauto]. - match_option; [|tauto]. - unfold lift. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - apply vectoro_to_ovector_forall_some_b_strong; intros. - simpl_closed_backprop. - simpl_closed_backprop. - f_equal. - rewrite eqq1, eqq in H. - unfold lift in H; symmetry in H. - specialize (vectoro_to_ovector_forall_some_f H); intros. - specialize (H0 env). - simpler2. - cut_to H0; try congruence. - rewrite eqq0 in H0. - rewrite (split_subvar env env0 d3 val); trivial. - specialize (H9 i); simpl in H9. - specialize (vectoro_to_ovector_forall_some_f H9); intros. - replace (@vsum floatish_R p (fun j => (d i j * d2 j i0 + d1 i j * d0 j i0)%R)) - with ((vsum (fun j => (d i j * d2 j i0)%R)) + (vsum (fun j => (d1 i j * d0 j i0)%R)))%R - ; [|rewrite vsum_plus; f_equal]. - unfold lift in H0; symmetry in H0. - specialize (vectoro_to_ovector_forall_some_f H0); intros. - simpl in H10; simpl in H11. - destruct i; destruct i0; simpl in eqq2; simpl in eqq3. - unfold matrix_mult,UnitMatrix in eqq3; simpl in eqq3. - unfold matrix_mult,UnitMatrix in eqq2; simpl in eqq2. - generalize (df_eval_backprop_delta_by_unit_partmat - σ l s grad_env - (fun (i : {n' : nat | n' < m}) (k : {m' : nat | m' < p}) => - vsum - (fun j : {n' : nat | n' < n} => - ((if equiv_dec (` i) x then - if equiv_dec (` j) x0 then 1 else 0 else 0) * d0 k j)%R)) - (fun (i : {n' : nat | n' < m}) (k : {m' : nat | m' < p}) => - vsum - (fun j : {n' : nat | n' < n} => - ((if equiv_dec (` i) x then - if equiv_dec (` j) x0 then 1 else 0 else 0) * - (d1 i k) * d0 k j)%R))) - ; intros. - specialize (H12 H1 vin). - cut_to H12; try congruence. - + unfold df_eval_backprop_delta in H12. - rewrite eqq1 in H12. - unfold lift in H12. - rewrite eqq2 in H12. - invcs H12. - rewrite H14. - assert (msum - (fun (i : {n' : nat | (n' < m)%nat}) (k : {m' : nat | (m' < p)%nat}) => - vsum - (fun j : {n' : nat | (n' < n)%nat} => - (if equiv_dec (` i) x then - if equiv_dec (` j) x0 then 1 else 0 else 0) - * d1 i k * d0 k j))%R = - vsum - (fun j : {n' : nat | (n' < p)%nat} => - d1 (exist (fun n' : nat => (n' < m)%nat) x l0) j * - d0 j (exist (fun n' : nat => (n' < n)%nat) x0 l1))%R). - * rewrite msum_transpose. - unfold msum. - f_equal. - unfold transpose; simpl. - replace (fun (i : {m' : nat | m' < p}) (j : {n' : nat | n' < m}) => - (@vsum floatish_R n - (fun j0 : {n' : nat | n' < n} => - ((if equiv_dec (` j) x then - if equiv_dec (` j0) x0 then 1 else 0 - else 0) * d1 j i * d0 i j0)%R))) with - (fun (i : {m' : nat | m' < p}) (j : {n' : nat | n' < m}) => - if equiv_dec (` j) x then - (d1 (exist _ x l0) i * d0 i (exist _ x0 l1))%R - else 0%R). - -- apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - replace (fun j : {n' : nat | n' < m} => - if equiv_dec (` j) x - then - (d1 (exist (fun n' : nat => (n' < m)%nat) x l0) x1 * - d0 x1 (exist (fun m' : nat => (m' < n)%nat) x0 l1))%R - else 0%R) with - (fun j : {n' : nat | n' < m} => - ((d1 (exist (fun n' : nat => (n' < m)%nat) x l0) x1 * - d0 x1 (exist (fun m' : nat => (m' < n)%nat) x0 l1))%R * - (@UnitVector floatish_R m (exist _ x l0) j))%R). - now rewrite vsum_unitvector. - apply FunctionalExtensionality.functional_extensionality; intros. - unfold UnitVector; simpl. - destruct (equiv_dec (` x2) x); lra. - -- apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct ( equiv_dec (` x2) x). - ++ replace - (fun j0 : {n' : nat | n' < n} => - ((if equiv_dec (` j0) x0 then 1 else 0) * d1 x2 x1 * d0 x1 j0)%R) - with - (fun j0 : {n' : nat | n' < n} => - ((d1 x2 x1 * d0 x1 j0) * (UnitVector n (exist _ x0 l1) j0))%R). - ** rewrite vsum_unitvector. - red in e. - subst. - destruct x2. - simpl. - erewrite index_pf_irrel; eauto. - ** apply FunctionalExtensionality.functional_extensionality; intros. - unfold UnitVector; simpl. - destruct (equiv_dec (` x3) x0); lra. - ++ rewrite <- vsum_mult. - lra. - * rewrite H12. - apply Rplus_eq_compat_r. - generalize (df_eval_backprop_delta_by_unit_partmat - σ r s env - (fun (i : {n' : nat | n' < p}) (k : {m' : nat | m' < n}) => - vsum - (fun j : {n' : nat | n' < m} => - (d j i * - (if equiv_dec (` j) x then - if equiv_dec (` k) x0 then 1 else 0 else 0))%R)) - (fun (i : {n' : nat | n' < p}) (k : {m' : nat | m' < n}) => - vsum - (fun j : {n' : nat | n' < m} => - (d j i * d2 i k * - (if equiv_dec (` j) x then - if equiv_dec (` k) x0 then 1 else 0 else 0))%R))) - ; intros. - specialize (H13 H2). - cut_to H13; try congruence. - -- unfold df_eval_backprop_delta in H13. - rewrite eqq4 in H13. - unfold lift in H13. - rewrite eqq3 in H13. - invcs H13. - rewrite H16. - unfold msum; f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - rewrite vmap_nth. - replace - (fun k : {m' : nat | m' < n} => - (@vsum floatish_R m - (fun j : {n' : nat | n' < m} => - (d j x1 * d2 x1 k * - (if equiv_dec (` j) x then - if equiv_dec (` k) x0 then 1 else 0 else 0))%R))) with - (fun k : {m' : nat | m' < n} => - (vsum - (fun j => - (d j x1 * d2 x1 k * - (UnitVector m (exist _ x l0) j))%R) - * (UnitVector n (exist _ x0 l1) k))%R). - ++ rewrite vsum_unitvector. - rewrite vsum_unitvector. - lra. - ++ apply FunctionalExtensionality.functional_extensionality; intros. - rewrite Rmult_comm. - rewrite vsum_mult. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - unfold UnitVector; simpl. - destruct (equiv_dec (` x2) x0); destruct (equiv_dec (` x3) x); lra. - -- intros. - rewrite scalarMult_backprop_grad_scalar with (grad_env2:=env); trivial; try congruence. - ++ unfold lift. - unfold df_eval_backprop_delta. - rewrite eqq4. - unfold lift. - specialize (H11 i); simpl in H11. - specialize (vectoro_to_ovector_forall_some_f H11); intros. - specialize (H15 j); intros; simpl in H15. - destruct i; destruct j; simpl in H15. - match_option_in H15. - f_equal; simpl. - invcs H15. - rewrite H17. - rewrite Rmult_comm. - rewrite vsum_mult. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - ++ apply backprop_deriv_fully_closed_not_none; trivial. - ++ apply backprop_deriv_fully_closed_not_none; trivial. - + intros. - rewrite scalarMult_backprop_grad_scalar with (grad_env2:=grad_env); trivial; try congruence. - * unfold lift. - unfold df_eval_backprop_delta. - rewrite eqq1. - unfold lift. - specialize (vectoro_to_ovector_forall_some_f H i); intros; simpl in H11. - specialize (vectoro_to_ovector_forall_some_f H13 j); intros; simpl in H14. - destruct i; destruct j; simpl in H14. - match_option_in H14. - f_equal; simpl. - invcs H14. - rewrite H16. - rewrite Rmult_comm. - rewrite vsum_mult. - f_equal. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - * apply backprop_deriv_fully_closed_not_none; trivial. - * apply backprop_deriv_fully_closed_not_none; trivial. - - Case "VectorPlus"%string. - destruct closed. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3. - match_option; symmetry. - * apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H3 i). - unfold lift in H3; unfold lift. - match_option_in H3. - unfold lift in H0. - rewrite eqq1 in H0. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat)) - ;intros. - specialize (H4 vin). - specialize (H0 d2 H4 H2). - match_option_in H0. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H5. - specialize (H5 i). - destruct i; simpl in H1; simpl; simpl in eqq2; simpl in H5. - rewrite eqq2. - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (UnitVector n (exist (fun n' : nat => n' < n) x l0))) - ; intros; specialize (H6 H2). - match_option; [|tauto]; f_equal. - rewrite eqq4 in H5; inversion H5. - rewrite (split_subvar d2 d4 d0 d3); trivial. - inversion H3; lra. - * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros. - specialize (H4 H2). - tauto. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_exists_None H); intros. - destruct H3. - unfold lift in e. - match_option_in e. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_2 - UnitAnn (DTVector n) l n eq_refl x) - (UnitVector n x))); intros. - specialize (H3 H1); tauto. - - Case "VectorMinus"%string. - destruct closed. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3. - match_option; symmetry. - * apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H3 i). - unfold lift in H3; unfold lift. - match_option_in H3. - unfold lift in H0. - rewrite eqq1 in H0. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat)) - ;intros. - specialize (H4 vin). - specialize (H0 d2 H4 H2). - match_option_in H0. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H5. - specialize (H5 i). - destruct i; simpl in H0; simpl; simpl in eqq2; simpl in H5. - rewrite eqq2. - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (UnitVector n (exist (fun n' : nat => n' < n) x l0))) - ; intros; specialize (H6 H2). - generalize (scalarMult_backprop_grad_scalar - σ r s d2 d2 - (UnitVector n (exist (fun n' : nat => n' < n) x l0)) - (-1)%R ); intros. - specialize (H7 H4 H4). - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (scalarMult (DTVector n) (-1)%R - (UnitVector n (exist (fun n' : nat => n' < n) x l0)))) - ; intros; specialize (H8 H2). - specialize (H7 H8 H6). - unfold df_eval_backprop_delta in H7; simpl in H7. - rewrite eqq3 in H7. - unfold lift in H7; simpl in H7. - replace (fun i : {n' : nat | n' < n} => - (- (@UnitVector floatish_R n (exist (fun n' : nat => (n' < n)%nat) x l0)) i)%R) - with - (fun i : {n' : nat | n' < n} => - (-1 * UnitVector n (exist (fun n' : nat => (n' < n)%nat) x l0) i)%R). - match_option; [|tauto]; f_equal. - rewrite eqq4 in H7. - rewrite (split_subvar d2 d4 d0 d3); trivial. - rewrite H5 in H7. - inversion H7. - rewrite H10. - inversion H3; lra. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros. - specialize (H4 H2). - tauto. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_exists_None H); intros. - destruct H3. - unfold lift in e. - match_option_in e. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_2 UnitAnn (DTVector n) l n eq_refl x) - (UnitVector n x))); intros. - specialize (H3 H1); tauto. - - Case "MatrixPlus"%string. - destruct closed. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H1. - match_option; symmetry. - * apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H3 i). - specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H4 i0). - unfold lift in H4; unfold lift. - match_option_in H4. - unfold lift in H0. - rewrite eqq1 in H0. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat)) - ;intros. - specialize (H5 vin). - specialize (H0 d2 H5 H2). - match_option_in H0. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6. - specialize (H6 i). - specialize (apply vectoro_to_ovector_forall_some_f H6); intros; simpl in H7. - specialize (H7 i0). - destruct i; destruct i0; simpl; simpl in eqq2; simpl in H7. - rewrite eqq2. - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0) - (exist (fun n' : nat => n' < m) x0 l1))) - ; intros; specialize (H8 H2). - match_option; [|tauto]; f_equal. - rewrite eqq4 in H7; inversion H7. - rewrite (split_subvar d2 d4 d0 d3); trivial. - inversion H4; lra. - * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros. - now specialize (H4 H2). - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_exists_None H); intros. - destruct H3. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H3. - unfold lift in e0. - match_option_in e0. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_3 UnitAnn (DTMatrix n m) l n m eq_refl x - x0) (UnitMatrix n m x x0))); intros. - specialize (H3 H1); tauto. - - Case "MatrixMinus"%string. - destruct closed. - specialize (H grad_env vin). - simpl in *. - match_option - ; rewrite eqq in H - ; simpl in *; match_option_in H; [|tauto|]. - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H1. - match_option; symmetry. - * apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H3 i). - specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H4 i0). - unfold lift in H4; unfold lift. - match_option_in H4. - unfold lift in H0. - rewrite eqq1 in H0. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat)) - ;intros. - specialize (H5 vin). - specialize (H0 d2 H5 H2). - match_option_in H0. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6. - specialize (H6 i). - specialize (apply vectoro_to_ovector_forall_some_f H6); intros; simpl in H7. - specialize (H7 i0). - destruct i; destruct i0; simpl; simpl in eqq2; simpl in H7. - rewrite eqq2. - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0) - (exist (fun n' : nat => n' < m) x0 l1))) - ; intros; specialize (H8 H2). - - generalize (scalarMult_backprop_grad_scalar - σ r s d2 d2 - (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0) - (exist (fun n' : nat => n' < m) x0 l1)) - (-1)%R ); intros. - specialize (H9 H5 H5). - generalize (backprop_deriv_fully_closed_not_none - σ r d2 - (scalarMult (DTMatrix n m) (-1)%R - (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0) - (exist (fun n' : nat => n' < m) x0 l1)))) - ; intros; specialize (H10 H2). - specialize (H9 H10 H8). - unfold df_eval_backprop_delta in H9; simpl in H9. - rewrite eqq3 in H9. - unfold lift in H9; simpl in H9. - replace - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (- - (@UnitMatrix floatish_R n m (exist (fun n' : nat => (n' < n)%nat) x l0) - (exist (fun n' : nat => (n' < m)%nat) x0 l1)) i j)%R) - with - (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (-1 * - UnitMatrix n m (exist (fun n' : nat => (n' < n)%nat) x l0) - (exist (fun n' : nat => (n' < m)%nat) x0 l1) i j)%R). - match_option; [|tauto]; f_equal. - rewrite eqq4 in H9. - rewrite (split_subvar d2 d4 d0 d3); trivial. - rewrite H7 in H9. - inversion H9. - rewrite H12. - inversion H4; lra. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - lra. - * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros. - now specialize (H4 H2). - + specialize (H H1). - symmetry in H. - specialize (apply vectoro_to_ovector_exists_None H); intros. - destruct H3. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H3. - unfold lift in e0. - match_option_in e0. - generalize (backprop_deriv_fully_closed_not_none - σ l grad_env - (coerce - (df_eval_backward_gen_top_obligation_3 UnitAnn (DTMatrix n m) l n m eq_refl x - x0) (UnitMatrix n m x x0))); intros. - specialize (H3 H1); tauto. - - Case "VectorScalMult"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ x); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H2). - case_eq (df_eval σ l); [intros|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - case_eq (df_eval_backprop_deriv σ x grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - match_option; unfold lift; symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - replace (@vsum - floatish_R n - (fun j : @sig nat (fun n' : nat => lt n' n) => - Rmult (d1 j) - (@coerce (Vector R n) (Vector R n) - (@df_eval_backward_gen_top_obligation_2 - floatish_R UnitAnn - (DTVector n) - (@VectorScalMult floatish_R UnitAnn n ann x l) n - (@eq_refl definition_function_types (DTVector n)) i) - (@UnitVector floatish_R n i) j))) - with (d1 i). - generalize (scalarMult_backprop_grad_scalar - σ x s grad_env grad_env 1%R (d1 i)); intros; simpl in H7. - unfold df_eval_backprop_delta in H7. - rewrite eqq1 in H7; simpl in H7. - specialize (H7 vin vin). - rewrite eqq0 in H7. - generalize (backprop_deriv_fully_closed_not_none - σ x grad_env (d1 i * 1)%R ); intros. - specialize (H8 H1); specialize (H7 H8). - cut_to H7; try discriminate. - invcs H. - { specialize (H0 d3). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - | apply H2]. - unfold lift in H7; simpl in H7. - match_option_in H7. - replace (d1 i) with (d1 i * 1)%R by lra. - rewrite eqq3. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros. - specialize (H vin). - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H9 in H0. - unfold lift; simpl. - unfold lift in H0. - rewrite eqq2 in H0. - generalize (scalarMult_backprop_grad_scalar σ l s d2 d3 (UnitVector n i) d) - ; intros; simpl in H10. - unfold df_eval_backprop_delta in H10; simpl in H10. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat));intros. - specialize (H11 vin). - case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence]. - specialize (H10 H11 H). - rewrite H9, H12 in H10. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros. - specialize (H13 i); simpl in H13. - generalize (backprop_deriv_fully_closed_not_none - σ l d2 - (fun i0 : {n' : nat | n' < n} => (d * UnitVector n i i0)%R)); intros. - specialize (H14 H2); specialize (H10 H14). - generalize (backprop_deriv_fully_closed_not_none - σ l d3 (UnitVector n i)); intros. - specialize (H15 H2); specialize (H10 H15). - unfold lift in H10. - match_option_in H10; [|congruence]; f_equal. - replace - (fun j : @sig nat (fun n' : nat => lt n' n) => - Rmult d - (@coerce (Vector R n) (Vector R n) - (@df_eval_backward_gen_top_obligation_2 - floatish_R UnitAnn - (DTVector n) (@VectorScalMult floatish_R UnitAnn n ann x l) n - (@eq_refl definition_function_types (DTVector n)) i) - (@UnitVector floatish_R n i) j)) - with - (fun i0 : {n' : nat | n' < n} => (d * UnitVector n i i0)%R). - rewrite eqq4. - rewrite (split_subvar d2 d7 d0 d6); trivial; f_equal. - match_option_in H10; inversion H10. - match_option_in H13. - invcs H7; invcs H13. - rewrite H17. - destruct i; simpl in eqq6. - rewrite eqq6 in eqq5. - invcs eqq5. - lra. - apply FunctionalExtensionality.functional_extensionality; intros. - now destruct i; simpl. - } - + symmetry. - destruct i; simpl. - apply vsum_unitvector. - + specialize (H0 d3). - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros. - specialize (H7 vin). - specialize (H0 H7 H2). - rewrite eqq2 in H0. - unfold lift in H0. - symmetry in H0. - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H8 in H0; simpl in H0. - specialize (apply vectoro_to_ovector_exists_None H0); intros. - destruct H9. - generalize (backprop_deriv_fully_closed_not_none - σ l d3 - (coerce - (df_eval_backward_gen_top_obligation_2 - UnitAnn (DTVector n) l n eq_refl x0) - (UnitVector n x0))); intros. - specialize (H9 H2). - match_option_in e. - tauto. - + unfold lift in H. - generalize (backprop_deriv_fully_closed_not_none - σ x grad_env 1%R); intros. - specialize (H7 H1). - match_option_in H. - tauto. - - Case "MatrixScalMult"%string. - destruct closed. - specialize (H grad_env vin H1). - simpl in *. - generalize (eval_fully_closed_not_none σ x); intros. - specialize (H3 H1). - match_case; [intros|tauto]. - case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto]. - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H2). - case_eq (df_eval σ l); [intros|tauto]. - match_option - ; rewrite eqq in H - ; simpl in *; rewrite eqq0 in H. - case_eq (df_eval_backprop_deriv σ x grad_env 1%R) - ; [intros ? eqq1 | intros eqq1] - ; rewrite eqq1 in H - ; simpl in * - ; try discriminate. - match_option; unfold lift; symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - apply vectoro_to_ovector_forall_some_b_strong; intros. - replace - (@msum floatish_R n m - (fun (i1 : @sig nat (fun n' : nat => lt n' n)) - (j : @sig nat (fun m' : nat => lt m' m)) => - Rmult (d1 i1 j) - (@coerce (Matrix R n m) (Matrix R n m) - (@df_eval_backward_gen_top_obligation_3 floatish_R UnitAnn - (DTMatrix n m) (@MatrixScalMult floatish_R UnitAnn n m ann x l) n m - (@eq_refl definition_function_types (DTMatrix n m)) i i0) - (@UnitMatrix floatish_R n m i i0) i1 j))) - with (d1 i i0). - generalize (scalarMult_backprop_grad_scalar - σ x s grad_env grad_env 1%R (d1 i i0)); intros; simpl in H7. - unfold df_eval_backprop_delta in H7. - - rewrite eqq1 in H7; simpl in H7. - specialize (H7 vin vin). - rewrite eqq0 in H7. - generalize (backprop_deriv_fully_closed_not_none - σ x grad_env (d1 i i0 * 1)%R ); intros. - specialize (H8 H1); specialize (H7 H8). - cut_to H7; try discriminate. - invcs H. - { specialize (H0 d3). - cut_to H0; - [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)) - | apply H2]. - unfold lift in H7; simpl in H7. - match_option_in H7. - replace (d1 i i0) with (d1 i i0 * 1)%R by lra. - rewrite eqq3. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros. - specialize (H vin). - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H9 in H0. - unfold lift in H0. - rewrite eqq2 in H0. - generalize (scalarMult_backprop_grad_scalar σ l s d2 d3 (UnitMatrix n m i i0) d) - ; intros; simpl in H10. - unfold df_eval_backprop_delta in H10; simpl in H10. - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat));intros. - specialize (H11 vin). - case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence]. - specialize (H10 H11 H). - rewrite H9, H12 in H10. - symmetry in H0. - specialize (apply vectoro_to_ovector_forall_some_f H0); intros. - specialize (H13 i); simpl in H13. - specialize (apply vectoro_to_ovector_forall_some_f H13); intros. - specialize (H14 i0); simpl in H14. - generalize (backprop_deriv_fully_closed_not_none - σ l d2 - (fun (i1 : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => - (d * UnitMatrix n m i i0 i1 j)%R)); intros. - specialize (H15 H2); specialize (H10 H15). - generalize (backprop_deriv_fully_closed_not_none - σ l d3 (UnitMatrix n m i i0)); intros. - specialize (H16 H2); specialize (H10 H16). - unfold lift in H10. - match_option_in H10; [|congruence]; f_equal. - replace - (fun (i1 : @sig nat (fun n' : nat => lt n' n)) - (j : @sig nat (fun m' : nat => lt m' m)) => - Rmult - (@coerce (Matrix R n m) (Matrix R n m) - (@df_eval_backward_gen_top_obligation_3 - floatish_R UnitAnn - (DTMatrix n m) (@MatrixScalMult floatish_R UnitAnn n m ann x l) n m - (@eq_refl definition_function_types (DTMatrix n m)) i i0) - (@UnitMatrix floatish_R n m i i0) i1 j) d) - with - (fun i1 j => (d * UnitMatrix n m i i0 i1 j)%R). - rewrite eqq4. - rewrite (split_subvar d2 d7 d0 d6); trivial; f_equal. - match_option_in H10; inversion H10. - match_option_in H14. - rewrite H18. - invcs H7; invcs H14. - destruct i; destruct i0; simpl in eqq6. - rewrite eqq6 in eqq5. - invcs eqq5. - rewrite H17. - lra. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - destruct i; destruct i0; simpl. - lra. - } - + symmetry. - destruct i; destruct i0; simpl. - apply msum_unitmatrix. - + specialize (H0 d3). - generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros. - specialize (H7 vin). - specialize (H0 H7 H2). - rewrite eqq2 in H0. - unfold lift in H0. - symmetry in H0. - case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence]. - rewrite H8 in H0; simpl in H0. - specialize (apply vectoro_to_ovector_exists_None H0); intros. - destruct H9. - specialize (apply vectoro_to_ovector_exists_None e); intros. - destruct H9. - generalize (backprop_deriv_fully_closed_not_none - σ l d3 - (coerce - (df_eval_backward_gen_top_obligation_3 - UnitAnn (DTMatrix n m) l n m eq_refl x0 x1) - (UnitMatrix n m x0 x1))); intros. - specialize (H9 H2). - match_option_in e0. - tauto. - + unfold lift in H. - generalize (backprop_deriv_fully_closed_not_none - σ x grad_env 1%R); intros. - specialize (H7 H1). - match_option_in H. - tauto. - - Case "VectorApply"%string. - destruct closed. - generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (0%R) :: nil) s0); intros. - specialize (H4 H2). - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H3). - match_case; [intros|tauto]. - specialize (H1 grad_env vin H3). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto]. - rewrite eqq1 in H1; simpl in H1. - match_option - ; rewrite eqq in H1; unfold lift in H1; symmetry in H1. - + specialize (apply vectoro_to_ovector_forall_some_f H1); intros; simpl in H7. - unfold lift; simpl. - match_option. - * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H8 i). - specialize (H7 i). - destruct i. - simpl. - match_nested_case. - -- assert ( df_eval_backprop_deriv σ l grad_env v1 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - f_equal. - simpl in H7. - match_option_in H8. - specialize (vectoro_to_ovector_forall_some_f eqq2); intros. - assert (H10c := H10). - specialize (H10 (exist _ x l0)). - rewrite vmap_nth in H10; simpl in H10. - match_option_in H10; simpl in H10. - unfold UnitVector in H10; simpl in H10. - destruct (equiv_dec x x); [|congruence]. - invcs H8; rewrite H12. - assert (v1 = - scalarMult (DTVector n) d4 - (UnitVector n (exist (fun n' : nat => n' < n) x l0))). - ++ unfold scalarMult; simpl. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H10c x0); simpl in H10c. - rewrite vmap_nth in H10c; simpl in H10c. - match_option_in H10c. - invcs H10c. - destruct x0. - unfold UnitVector; simpl. - destruct (equiv_dec x0 x). - ** red in e0. - subst. - erewrite index_pf_irrel in eqq6. - rewrite eqq5 in eqq6. - invcs eqq6; lra. - ** lra. - ++ replace (1 * d4)%R with d4 in H10 by lra. - generalize (scalarMult_backprop_grad_scalar - σ l s grad_env grad_env - (UnitVector n (exist (fun n' : nat => n' < n) x l0)) d4) - ; intros. - simpl in H11; cut_to H11; trivial; try congruence. - ** unfold df_eval_backprop_delta in H11. - rewrite eqq1 in H11. - unfold lift in H11; simpl in H11. - rewrite H8 in eqq3. - unfold scalarMult in eqq3; simpl in eqq3. - match_option_in H7. - rewrite eqq3, eqq6 in H11. - invcs H11; rewrite H14. - invcs H7; rewrite H11. - generalize - (df_eval_deriv_genvar_same - [mk_env_entry (v, DTfloat) - (d (exist (fun n' : nat => n' < n) x l0) )] - s0 v); simpl; intros. - specialize (H7 H H2). - unfold lift, df_eval_deriv_gen_top in H7; simpl in H7. - rewrite <- H12. - unfold mk_genvar_env in eqq4; simpl in eqq4. - rewrite eqq4, eqq5 in H7. - invcs H7; lra. - ** apply backprop_deriv_fully_closed_not_none; trivial. - ** apply backprop_deriv_fully_closed_not_none; trivial. - -- apply vectoro_to_ovector_exists_None in eqq2; destruct eqq2. - rewrite vmap_nth in e; simpl in e. - assert (df_eval [mk_env_entry (v, DTfloat) (d x0)] - (df_deriv s0 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - match_option_in e; tauto. - * apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0. - assert (df_eval_deriv_genvar [mk_env_entry (v, DTfloat) (d x)] s0 - [mk_env_entry (v, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option_in e; tauto. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - - Case "MatrixApply"%string. - destruct closed. - generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (0%R) :: nil) s0); intros. - specialize (H4 H2). - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H3). - match_case; [intros|tauto]. - specialize (H1 grad_env vin H3). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto]. - rewrite eqq1 in H1; simpl in H1. - match_option - ; rewrite eqq in H1; unfold lift in H1; symmetry in H1. - + specialize (apply vectoro_to_ovector_forall_some_f H1); intros; simpl in H7. - unfold lift; simpl. - match_option. - * specialize (apply vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8. - symmetry. - apply vectoro_to_ovector_forall_some_b_strong; intros. - apply vectoro_to_ovector_forall_some_b_strong; intros. - specialize (H8 i). - specialize (H7 i). - match_nested_case. - -- assert ( df_eval_backprop_deriv σ l grad_env m1 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - f_equal. - simpl in H7. - specialize (vectoro_to_ovector_forall_some_f H8); intros. - specialize (vectoro_to_ovector_forall_some_f H7); intros. - specialize (H10 i0); simpl in H10. - match_option_in H10. - specialize (vectoro_to_ovector_forall_some_f eqq2); intros. - assert (H12c := H12). - specialize (H12 i); simpl in H12. - specialize (vectoro_to_ovector_forall_some_f H12); intros. - specialize (H13 i0); simpl in H13; unfold mmap in H13. - rewrite vmap_nth in H13; simpl in H13. - rewrite vmap_nth in H13; simpl in H13. - destruct i; destruct i0; simpl in H13. - simpl in *. - unfold matrix_zip in H13. - rewrite vmap_nth in H13; simpl in H13. - match_option_in H13; simpl in H13. - unfold UnitMatrix in H13; simpl in H13. - destruct (equiv_dec x x); [|congruence]. - destruct (equiv_dec x0 x0); [|congruence]. - invcs H13; invcs H10. - assert (m1 = - scalarMult (DTMatrix m n) d4 - (UnitMatrix m n - (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1))). - ++ unfold scalarMult; simpl. - apply FunctionalExtensionality.functional_extensionality; intros. - apply FunctionalExtensionality.functional_extensionality; intros. - specialize (H12c x1); simpl in H12c. - specialize (vectoro_to_ovector_forall_some_f H12c); intros. - specialize (H10 x2); simpl in H10; unfold mmap in H10. - - rewrite vmap_nth in H10; simpl in H10. - rewrite vmap_nth in H10; simpl in H10. - unfold matrix_zip in H10. - rewrite vmap_nth in H10; simpl in H10. - match_option_in H10. - invcs H10. - destruct x1; destruct x2. - unfold UnitMatrix; simpl. - destruct (equiv_dec x1 x). - ** red in e1; subst. - destruct (equiv_dec x2 x0). - --- red in e1; subst. - rewrite index_pf_irrel with (pf2 := l0) in eqq6. - rewrite index_pf_irrel with (pf1 := l3) (pf2 := l1) in eqq6. - rewrite eqq5 in eqq6. - invcs eqq6; lra. - --- lra. - ** lra. - ++ replace (1 * d4)%R with d4 in H15 by lra. - generalize (scalarMult_backprop_grad_scalar - σ l s grad_env grad_env - (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1)) d4) - ; intros. - simpl in H13; cut_to H13; trivial; try congruence. - ** unfold df_eval_backprop_delta in H13. - rewrite eqq1 in H13. - unfold lift in H13; simpl in H13. - rewrite H10 in eqq3. - unfold scalarMult in eqq3; simpl in eqq3. - specialize (H11 (exist (fun n' : nat => n' < n) x0 l1)). - simpl in H11. - match_option_in H11. - rewrite eqq3, eqq6 in H13. - - generalize - (df_eval_deriv_genvar_same - [mk_env_entry (v, DTfloat) - (d (exist (fun n' : nat => n' < m) x l0) - (exist (fun n' : nat => n' < n) x0 l1))] - s0 v). - simpl; intros. - specialize (H16 H H2). - unfold lift, df_eval_deriv_gen_top in H16; simpl in H16. - unfold mk_genvar_env in eqq4; simpl in eqq4. - rewrite eqq4, eqq5 in H16. - invcs H13; rewrite H18. - invcs H16. - invcs H11; rewrite H15; lra. - ** apply backprop_deriv_fully_closed_not_none; trivial. - ** apply backprop_deriv_fully_closed_not_none; trivial. - -- unfold matrixo_to_omatrix in eqq2. - apply vectoro_to_ovector_exists_None in eqq2; destruct eqq2. - apply vectoro_to_ovector_exists_None in e; destruct e. - unfold mmap in e. - rewrite vmap_nth in e; simpl in e. - rewrite vmap_nth in e; simpl in e. - rewrite matrix_zip_m_n in e. - destruct i; destruct i0. - simpl in e. - assert (df_eval [mk_env_entry (v, DTfloat) - (d x x0)] - (df_deriv s0 (v, DTfloat)) <> None). - apply eval_fully_closed_not_none. - apply fully_closed_deriv; trivial. - match_option_in e; tauto. - * unfold matrixo_to_omatrix in eqq0. - apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0. - apply vectoro_to_ovector_exists_None in e; destruct e. - assert (df_eval_deriv_genvar [mk_env_entry (v, DTfloat) (d x x0)] s0 - [mk_env_entry (v, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - match_option_in e; tauto. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - - Case "VLossfun"%string. - destruct closed. - generalize (eval_fully_closed_not_none - (mk_env_entry (v1, DTfloat) (0%R) :: - (mk_env_entry (v2, DTfloat) (0%R) :: nil)) s0); intros. - specialize (H4 H2). - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H3). - match_case; [intros|tauto]. - specialize (H1 grad_env vin H3). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto]. - rewrite eqq1 in H1; simpl in H1. - match_option - ; rewrite eqq in H1; unfold lift in H1; symmetry in H1. - + specialize (vectoro_to_ovector_forall_some_f H1); intros; simpl in H7. - unfold lift; simpl. - match_nested_case. - * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8. - symmetry. - match_nested_case. - -- specialize (vectoro_to_ovector_forall_some_f eqq2); intros; simpl in H9. - assert ( df_eval_backprop_deriv σ l grad_env v0 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - unfold snd in d1; simpl in d1. - assert (forall i : {n' : nat | n' < n}, - df_eval [mk_env_entry (v1, DTfloat) (d i); mk_env_entry (v2, DTfloat) (r i)] - (df_deriv s0 (v1, DTfloat)) - = Some (v0 i)). - intros ; specialize (H9 i) - ; rewrite vmap_nth in H9 - ; simpl in H9 - ; match_option_in H9 - ;inversion H9 ;f_equal; lra. - generalize (df_eval_backprop_delta_by_unit_partvec σ l s grad_env v0 v). - intros. - specialize (H12 H3 vin). - cut_to H12. - ++ unfold df_eval_backprop_delta, lift in H12. - now rewrite eqq1, eqq3 in H12. - ++ intros. - replace (@scaleUnitVector (@float floatish_R) n i (v0 i) (IZR Z0)) with - (scalarMult (DTVector n) (v0 i) (UnitVector n i)) by - (unfold scalarMult, scaleUnitVector, UnitVector; - apply FunctionalExtensionality.functional_extensionality; intros; - simpl; destruct (equiv_dec (` x) (` i)); lra). - rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env); trivial. - ** unfold df_eval_backprop_delta. - rewrite eqq1. - specialize (H7 i); simpl in H7. - specialize (H8 i); simpl in H8. - specialize (H9 i); simpl in H9. - generalize - (df_eval_deriv_genvar_same - [mk_env_entry (v1, DTfloat) (d i); mk_env_entry (v2, DTfloat) (r i)] - s0 v1); simpl; intros. - specialize (H13 H H2). - unfold df_eval_deriv_gen_top in H13; simpl in H13. - unfold lift in H13; simpl in H13. - match_option_in H13. - --- unfold mk_genvar_env in H8; simpl in H8. - rewrite eqq4 in H8. - specialize (H11 i). - destruct i; simpl in H7. - match_option_in H7. - unfold lift; f_equal. - unfold scalarMult; simpl. - invcs H8. - invcs H7. - rewrite H14. - rewrite <- H13 in H11. - invcs H11. - lra. - --- assert (df_eval_deriv_genvar - [mk_env_entry (v1, DTfloat) (d i); - mk_env_entry (v2, DTfloat) (r i)] s0 - [mk_env_entry (v1, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - ** apply backprop_deriv_fully_closed_not_none; trivial. - ** apply backprop_deriv_fully_closed_not_none; trivial. - -- apply vectoro_to_ovector_exists_None in eqq2. - destruct eqq2. - rewrite vmap_nth in e; simpl in e. - match_option_in e. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x); - mk_env_entry (v2, DTfloat) (r x)] - (df_deriv s0 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none; simpl. - apply fully_closed_deriv; trivial. - tauto. - * apply vectoro_to_ovector_exists_None in eqq0. - destruct eqq0. - match_option_in e. - assert (df_eval_deriv_genvar - [mk_env_entry (v1, DTfloat) (d x); - mk_env_entry (v2, DTfloat) (r x)] s0 - [mk_env_entry (v1, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - - Case "MLossfun"%string. - destruct closed. - generalize (eval_fully_closed_not_none - (mk_env_entry (v1, DTfloat) (0%R) :: - (mk_env_entry (v2, DTfloat) (0%R) :: nil)) s0); intros. - specialize (H4 H2). - generalize (eval_fully_closed_not_none σ l); intros. - specialize (H5 H3). - match_case; [intros|tauto]. - specialize (H1 grad_env vin H3). - case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto]. - rewrite eqq1 in H1; simpl in H1. - match_option - ; rewrite eqq in H1; unfold lift in H1; symmetry in H1. - + specialize (vectoro_to_ovector_forall_some_f H1); intros; simpl in H7. - unfold lift; simpl. - match_nested_case. - * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8. - symmetry. - match_nested_case. - -- specialize (vectoro_to_ovector_forall_some_f eqq2); intros; simpl in H9. - assert ( df_eval_backprop_deriv σ l grad_env m1 <> None). - apply backprop_deriv_fully_closed_not_none; trivial. - match_option; [|tauto]. - unfold snd in d1; simpl in d1. - assert (forall (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}), - match - df_eval [mk_env_entry (v1, DTfloat) (d i j); - mk_env_entry (v2, DTfloat) (r i j)] - (df_deriv s0 (v1, DTfloat)) - with - | Some se => Some (1 * se / IZR (Z.of_nat n))%R - | None => None - end = Some (m1 i j)); intros. - ++ specialize (H9 i). - specialize (vectoro_to_ovector_forall_some_f H9); intros. - specialize (H11 j); simpl in H11; unfold mmap in H11. - do 2 rewrite vmap_nth in H11; simpl in H11. - unfold matrix_zip, vector_zip in H11; simpl in H11. - rewrite vmap_nth in H11; simpl in H11. - match_option; rewrite eqq4 in H11; apply H11. - ++ generalize (df_eval_backprop_delta_by_unit_partmat - σ l s grad_env m1 - (mmap (fun u => u / IZR (Z.of_nat n))%R m0)); intros. - specialize (H12 H3 vin). - cut_to H12. - ** unfold df_eval_backprop_delta, lift in H12. - rewrite eqq1, eqq3 in H12. - replace ((@msum floatish_R m n m0) / IZR (Z.of_nat n))%R with - (msum (mmap (fun u : R => (u / IZR (Z.of_nat n))%R) m0)). - --- apply H12. - --- now rewrite msum_mmap_div_denom. - ** intros. - rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env) - ; trivial. - --- unfold df_eval_backprop_delta. - rewrite eqq1. - specialize (H7 i); simpl in H7. - specialize (H8 i); simpl in H8. - specialize (H11 i j); simpl in H11. - specialize (vectoro_to_ovector_forall_some_f H7); intros. - specialize (H13 j); simpl in H13. - specialize (vectoro_to_ovector_forall_some_f H8); intros. - specialize (H14 j); simpl in H14. - generalize - (df_eval_deriv_genvar_same - [mk_env_entry (v1, DTfloat) (d i j); - mk_env_entry (v2, DTfloat) (r i j)] - s0 v1); simpl; intros. - specialize (H15 H H2). - unfold df_eval_deriv_gen_top in H15; simpl in H15. - unfold lift in H15; simpl in H15. - match_option_in H15. - +++ unfold mk_genvar_env in H14; simpl in H14. - rewrite eqq4 in H14. - destruct i; destruct j; simpl in H13. - match_option_in H13. - unfold lift; f_equal. - invcs H13. - invcs H14. - rewrite H17. - unfold mmap. - rewrite vmap_nth. - rewrite vmap_nth. - rewrite <- H16. - rewrite <- H15 in H11. - invcs H11. - lra. - +++ assert (df_eval_deriv_genvar - [mk_env_entry (v1, DTfloat) (d i j); - mk_env_entry (v2, DTfloat) (r i j)] s0 - [mk_env_entry (v1, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - --- apply backprop_deriv_fully_closed_not_none; trivial. - --- apply backprop_deriv_fully_closed_not_none; trivial. - -- unfold matrixo_to_omatrix in eqq2. - apply vectoro_to_ovector_exists_None in eqq2. - destruct eqq2. - apply vectoro_to_ovector_exists_None in e. - destruct e. - unfold mmap in e. - rewrite vmap_nth in e; simpl in e. - rewrite vmap_nth in e; simpl in e. - unfold matrix_zip,vector_zip in e; simpl in e. - rewrite vmap_nth in e. - match_option_in e. - assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] - (df_deriv s0 (v1, DTfloat)) <> None). - apply eval_fully_closed_not_none; simpl. - apply fully_closed_deriv; trivial. - tauto. - * unfold matrixo_to_omatrix in eqq0. - apply vectoro_to_ovector_exists_None in eqq0. - destruct eqq0. - apply vectoro_to_ovector_exists_None in e. - destruct e. - match_option_in e. - assert (df_eval_deriv_genvar - [mk_env_entry (v1, DTfloat) (d x x0); - mk_env_entry (v2, DTfloat) (r x x0)] s0 - [mk_env_entry (v1, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None). - apply eval_deriv_genvar_fully_closed_not_none; trivial. - tauto. - Qed. - -(* -Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "DVector"%string - | Case_aux c "DMatrix"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string - | Case_aux c "VectorDot"%string - | Case_aux c "VectorSum"%string - | Case_aux c "MatrixSum"%string - | Case_aux c "VectorElem"%string - | Case_aux c "MatrixElem"%string - | Case_aux c "MatrixVectorMult"%string - | Case_aux c "MatrixVectorAdd"%string - | Case_aux c "MatrixMult"%string - | Case_aux c "VectorPlus"%string - | Case_aux c "VectorMinus"%string - | Case_aux c "MatrixPlus"%string - | Case_aux c "MatrixMinus"%string - | Case_aux c "VectorScalMult"%string - | Case_aux c "MatrixScalMult"%string - | Case_aux c "VectorApply"%string - | Case_aux c "MatrixApply"%string - | Case_aux c "VLossfun"%string - | Case_aux c "MLossfun"%string]. - - - Lemma tree_backpropeq_complete_gen {T} (env gradenv : df_env) - (dfexpr : DefinedFunction EvalAnn T) (grad : definition_function_types_interp T) : - forall (x : SubVar), - let xvar := (x, DTfloat) in - vartlookup gradenv (x,DTfloat) <> None -> - match df_eval_tree_deriv env dfexpr xvar, - backprop_lookup (Some gradenv) xvar, - backprop_lookup (df_eval_tree_backprop_deriv env dfexpr gradenv grad) xvar - with - | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1 - | None, _, None => True - | _, _, _ => False - end. - Proof. - simpl. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case. - - - Case "Number"%string. - intros _ _ grad gradenv xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Constant"%string. - intros _ _ grad gradenv xinn. - destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra. - - Case "Var"%string. - intros sv _ grad gradenv xinn. - case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]. - destruct (var_dec x sv); simpl. - + subst. - rewrite H; simpl. - rewrite lookup_update. - destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence]. - unfold addvar; simpl. - rewrite H. - lra. - + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ]. - case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros. - * rewrite lookup_update_neq by congruence. - rewrite H. - lra. - * rewrite H. - lra. - - Case "Plus"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr grad ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' grad) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Minus"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl grad gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv grad) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl grad gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Times"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (get_annotation r * grad)%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (get_annotation l * grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (get_annotation l * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl (get_annotation r * grad)%R gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Divide"%string. - intros _ l r IHl IHr grad gradenv xinn. - case_eq (df_eval_tree_deriv env l (x, DTfloat)) - ; [intros dl eqdl | intros eqdl] - ; rewrite eqdl in IHl. - + case_eq (df_eval_tree_deriv env r (x, DTfloat)) - ; [intros dr eqdr | intros eqdr] - ; rewrite eqdr in IHr. - * specialize (IHl (grad / get_annotation r)%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHl - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHl - ; [ | tauto]. - simpl in IHl. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHl - ; [ | tauto]. - specialize (IHr (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R ge'). - cut_to IHr; [ | congruence ]. - rewrite xv'eqq in IHr. - destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R) (x, DTfloat)); trivial. - lra. - * case_eq ( df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R ); simpl; trivial; intros. - apply IHr. - apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn). - + specialize (IHl (grad / get_annotation r)%R gradenv xinn). - case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r )%R); simpl; trivial; intros. - rewrite H in IHl. - simpl in IHl. - generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros. - destruct (vartlookup d (x, DTfloat)); tauto. - - Case "Square"%string. - intros _ e IHe grad gradenv xinn. - - specialize (IHe (2 * (get_annotation e) * grad)%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (2 * (get_annotation e) * grad)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Exp"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad * exp (get_annotation e))%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * exp (get_annotation e))%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - - Case "Log"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad / get_annotation e)%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad / get_annotation e)%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - - Case "Abs"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe (grad * (sign (get_annotation e)))%R gradenv xinn). - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * (sign (get_annotation e)))%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - case_eq (vartlookup ge' (x, DTfloat)) - ; [intros xv' xv'eqq | intros xv'eqq] - ; rewrite xv'eqq in IHe - ; [ | tauto]. - simpl. - rewrite xv'eqq. - lra. - - Case "Sign"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe 0%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - replace (de * 0)%R with (0)%R in IHe by lra. - replace (0 * grad)%R with (0)%R by lra. - apply IHe. - - Case "PSign"%string. - intros _ e IHe grad gradenv xinn. - specialize (IHe 0%R gradenv xinn). - case_eq (vartlookup gradenv (x, DTfloat)) - ; [intros xv xveqq | intros xveqq] - ; rewrite xveqq in IHe - ; [ | tauto]. - case_eq (df_eval_tree_deriv env e (x, DTfloat)) - ; [intros de eqde | intros eqde] - ; rewrite eqde in IHe - ; trivial. - case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R) - ; [intros ge' ge'eq | intros ge'eq] - ; rewrite ge'eq in IHe - ; [ | tauto]. - simpl in IHe. - simpl. - replace (de * 0)%R with (0)%R in IHe by lra. - replace (0 * grad)%R with (0)%R by lra. - apply IHe. - - Case "Max"%string. - intros _ l r IHl IHr grad gradenv xinn. - specialize (IHl grad gradenv xinn). - specialize (IHr grad gradenv xinn). - destruct (Rle_dec (get_annotation l) (get_annotation r)); simpl. - destruct (df_eval_tree_deriv env r (x, DTfloat)); simpl; trivial. - destruct (df_eval_tree_deriv env l (x, DTfloat)); simpl; trivial. - Qed. -*) - -End real_pfs. - -(* - Section FreeVariablesExample. - (* We need to open the string scope in order to use "a" as a string. *) - Open Scope string_scope. - Theorem ex1 : (df_free_variables (Plus (Var "a") (Var "b"))) = "a"::"b"::nil. - Proof. - (* Reflexivity doesn't need syntactically identical things on either side of =. - * It suffices that the left-hand side beta-reduced to the right-hand side. *) - reflexivity. - Qed. - - Close Scope string_scope. - - End FreeVariablesExample. -*) diff --git a/coq/NeuralNetworks/testderiv.v b/coq/NeuralNetworks/testderiv.v deleted file mode 100644 index c2d6e9f6..00000000 --- a/coq/NeuralNetworks/testderiv.v +++ /dev/null @@ -1,1802 +0,0 @@ -Require Import String. -Require Import EquivDec. -Require Import RelationClasses. -Require Import List. -Require Import NPeano. -Require Import Lra Lia. -Require Reals. -Require Import Eqdep_dec. - -Require Import Floatish. -Require Import Utils. -Require Import derivlemmas. -Require Import Coquelicot.Hierarchy. -Require Import Coquelicot.Derive. -Require Import Coquelicot.Rcomplements. -Require Import DefinedFunctions. -Require FunctionalExtensionality. - -Set Bullet Behavior "Strict Subproofs". - -Section DefinedFunctions. - - Context {floatish_impl:floatish}. - Local Open Scope float. - -Section real_pfs. - - Local Existing Instance floatish_R. - Import Reals. - Import List. - - (* following returns None if not-differentiable *) - Fixpoint df_eval_deriv_exact {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v:var_type) : option (definition_function_types_interp T) - := (match df with - | Number _ _ => Some 0 - | Constant t _ x => Some - match t return definition_function_types_interp t with - | DTfloat => 0 - | DTVector n => ConstVector n 0 - | DTMatrix m n => ConstMatrix m n 0 - end - | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval_deriv_exact σ (dfs i) v) - | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval_deriv_exact σ (df i j) v) - | Var x _ => Some (let t:=snd x in - match t return definition_function_types_interp t with - | DTfloat => if x == v then 1 else 0 - | DTVector n => ConstVector n (if x == v then 1 else 0) - | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0) - end) - | Plus _ l r => - match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with - | Some le, Some lr => Some (le + lr) - | _, _ => None - end - | Minus _ l r => - match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with - | Some le, Some lr => Some (le - lr) - | _, _ => None - end - | Times _ l r => - match df_eval σ l, df_eval_deriv_exact σ l v, df_eval σ r, df_eval_deriv_exact σ r v with - | Some le, Some ld, Some re, Some rd => - Some (le * rd + - (ld * re)) - | _, _, _, _ => None - end - | Divide _ l r => - match df_eval σ l, df_eval_deriv_exact σ l v, df_eval σ r, df_eval_deriv_exact σ r v with - | Some le, Some ld, Some re, Some rd => - if Feq re 0 then None else - Some ((ld / re) - ((le * rd) / (re * re))) - | _, _, _, _ => None - end - | Square _ e => - match df_eval σ e, df_eval_deriv_exact σ e v with - | Some ee, Some ed => Some (2 * ee * ed) - | _, _ => None - end - | Exp _ e => - match df_eval σ e, df_eval_deriv_exact σ e v with - | Some ee, Some ed => Some (ed * Fexp ee) - | _, _ => None - end - | Log _ e => - match df_eval σ e, df_eval_deriv_exact σ e v with - | Some ee, Some ed => - if ee > 0 then Some (ed / ee) else None - | _, _ => None - end - | Abs _ e => - match df_eval σ e, df_eval_deriv_exact σ e v with - | Some ee, Some ed => - if Feq ee 0 then - (if Feq ed 0 then Some 0 else None) - else Some (ed * (sign ee)) - | _, _ => None - end - | Sign _ e => - match df_eval σ e, df_eval_deriv_exact σ e v with - | Some ee, Some ed => - if Feq ee 0 then None else Some 0 - | _, _ => None - end - | PSign _ e => - match df_eval σ e, df_eval_deriv_exact σ e v with - | Some ee, Some ed => - if Feq ee 0 then None else Some 0 - | _, _ => None - end - | Max _ l r => - match df_eval σ l, df_eval σ r, df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with - | Some le, Some re, Some ld, Some rd => - if Feq le re then - (if Feq ld rd then Some ld else None) - else - (if le < re then Some rd else Some ld) - | _, _, _, _=> None - end - | VectorElem n _ l i => - match (df_eval_deriv_exact σ l v) with - | Some l' => Some (l' i) - | _ => None - end - | MatrixElem m n _ l i j => - match (df_eval_deriv_exact σ l v) with - | Some l' => Some (l' i j) - | _ => None - end - | VectorDot n _ l r => - match df_eval σ l, df_eval_deriv_exact σ l v, df_eval σ r, df_eval_deriv_exact σ r v with - | Some le, Some ld, Some re, Some rd => - Some (vsum (fun j => (le j) * (rd j) + (ld j) * (re j))) - | _, _, _, _ => None - end - | VectorSum n _ l => - match df_eval_deriv_exact σ l v with - | Some ld => - Some (vsum ld) - | _ => None - end - | MatrixSum n m _ l => - match df_eval_deriv_exact σ l v with - | Some ld => - Some (msum ld) - | _ => None - end - | VectorScalMult n _ x r => - match df_eval σ x, df_eval_deriv_exact σ x v, df_eval σ r, df_eval_deriv_exact σ r v with - | Some xe, Some xd, Some re, Some rd => Some (fun j => xe * (rd j) + xd * (re j)) - | _, _, _, _ => None - end - | MatrixScalMult n m _ x r => - match df_eval σ x, df_eval_deriv_exact σ x v, df_eval σ r, df_eval_deriv_exact σ r v with - | Some xe, Some xd, Some re, Some rd => Some (fun i j => xe * (rd i j) + xd * (re i j)) - | _, _, _, _ => None - end - | MatrixVectorMult n m _ l r => - match df_eval σ l, df_eval_deriv_exact σ l v, df_eval σ r, df_eval_deriv_exact σ r v with - | Some le, Some ld, Some re, Some rd => - Some (fun i => vsum (fun j => (le i j)*(rd j) + (ld i j)*(re j))) - | _, _, _, _ => None - end - | MatrixVectorAdd n m _ l r => - match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with - | Some le, Some re => - Some (fun i j => (le i j) + (re i)) - | _, _ => None - end - | MatrixMult n m p _ l r => - match df_eval σ l, df_eval_deriv_exact σ l v, df_eval σ r, df_eval_deriv_exact σ r v with - | Some le, Some ld, Some re, Some rd => - Some (fun i k => vsum (fun j => (le i j)*(rd j k) + (ld i j)*(re j k))) - | _, _, _, _ => None - end - | VectorPlus n _ l r => - match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with - | Some l', Some r' => Some (fun i => (l' i) + (r' i)) - | _, _ => None - end - | VectorMinus n _ l r => - match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with - | Some l', Some r' => Some (fun i => (l' i) - (r' i)) - | _, _ => None - end - | MatrixPlus n m _ l r => - match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with - | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j)) - | _, _ => None - end - | MatrixMinus n m _ l r => - match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with - | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j)) - | _, _ => None - end - | VectorApply n _ x s r => - match df_eval σ r, df_eval_deriv_exact σ r v with - | Some re, Some rd => - vectoro_to_ovector - (fun i => - let xv := (x, DTfloat):var_type in - match df_eval_deriv_exact (cons (mk_env_entry xv (re i)) nil) s xv with - | Some sd => Some ((rd i) * sd) - | _ => None - end) - | _, _ => None - end - | MatrixApply n m _ x s r => - match df_eval σ r, df_eval_deriv_exact σ r v with - | Some re, Some rd => - matrixo_to_omatrix - (fun i j => - let xv := (x, DTfloat):var_type in - match df_eval_deriv_exact (cons (mk_env_entry xv (re i j)) nil) s xv with - | Some sd => Some ((rd i j) * sd) - | _ => None - end) - | _, _ => None - end - | VLossfun n _ v1 v2 s l r => - match df_eval σ l, df_eval_deriv_exact σ l v with - | Some le, Some ld => - match (vectoro_to_ovector - (fun i => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv_exact (cons (mk_env_entry xv1 (le i)) - (cons (mk_env_entry xv2 (r i)) nil)) s xv1 with - | Some sd => Some ((ld i) * sd) - | _ => None - end)) with - | Some vv => Some (vsum vv) - | _ => None - end - | _, _ => None - end - | MLossfun n m _ v1 v2 s l r => - match df_eval σ l, df_eval_deriv_exact σ l v with - | Some le, Some ld => - match (matrixo_to_omatrix - (fun i j => - let xv1 := (v1, DTfloat):var_type in - let xv2 := (v2, DTfloat):var_type in - match df_eval_deriv_exact (cons (mk_env_entry xv1 (le i j)) - (cons (mk_env_entry xv2 (r i j)) nil)) s xv1 with - | Some sd => Some ((ld i j) * sd) - | _ => None - end)) with - | Some vv => - if Nat.eq_dec m 0%nat then None else Some ((msum vv) / (FfromZ (Z.of_nat m))) - | _ => None - end - | _, _ => None - end - end). - - - Definition addBinding σ v x := (mk_env_entry (v,DTfloat) x)::σ. - - Definition df_eval_at_point {Ann T} σ (df:DefinedFunction Ann T) v x - := df_eval (addBinding σ v x) df. - - Definition df_R {Ann} (σ:df_env) (df:DefinedFunction Ann DTfloat) v : R -> R - := fun x => match df_eval_at_point σ df v x with - | Some y => y - | None => 0%R - end. - - Definition ex_deriv_df {Ann} σ (df:DefinedFunction Ann DTfloat) v (x:R) - := fully_closed_over df ((v,DTfloat)::map (@projT1 _ _) σ) /\ - ex_derive (df_R σ df v) x. - - Definition is_deriv_df {Ann} σ (df:DefinedFunction Ann DTfloat) v (x y:R) - := fully_closed_over df ((v,DTfloat)::map (@projT1 _ _) σ) /\ - is_derive (df_R σ df v) x y. - - - Lemma eval_at_point_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T) v x : - let vl := (v,DTfloat)::map (fun ve => projT1 ve) σ in - fully_closed_over df vl -> - {d:definition_function_types_interp T | df_eval_at_point σ df v x = Some d}. - Proof. - intros. - unfold df_eval_at_point. - destruct (eval_fully_closed_total (addBinding σ v x) df) as [dd pfd] - ; simpl; eauto. - Defined. - - Lemma eval_at_point_diferentiable_total {σ} {df:DefinedFunction UnitAnn DTfloat} {v x y} : - is_deriv_df σ df v x y -> - {xve | df_eval_at_point σ df v x = Some xve}. - Proof. - intros [closed _ ]. - destruct (eval_at_point_fully_closed_total σ df v x); eauto. - Defined. - - Definition eval_differentiable_at_point {σ} {df:DefinedFunction UnitAnn DTfloat} {v x y} - (pf_deriv:is_deriv_df σ df v x y) := - proj1_sig (eval_at_point_diferentiable_total pf_deriv). - - Lemma is_derive_df_exp - (df:DefinedFunction UnitAnn DTfloat) (σ:df_env) v x y - (pf_deriv:is_deriv_df σ df v x y) : - forall a, - is_deriv_df σ (Exp a df) v x (y * exp (eval_differentiable_at_point pf_deriv)). - Proof. - unfold eval_differentiable_at_point. - destruct (eval_at_point_diferentiable_total pf_deriv) as [xve eqqx]; simpl. - unfold is_deriv_df; simpl; destruct pf_deriv as [base_closed base_deriv]. - split; trivial. - generalize (is_derive_comp exp (df_R σ df v) x (exp xve) y) - ; intros isd. - unfold df_R, df_eval_at_point in *. - eapply is_derive_ext; [ | eapply isd]; trivial. - - intros; simpl. - match_option; simpl. - eelim eval_fully_closed_not_none; [ | eapply eqq]. - simpl; trivial. - - rewrite eqqx. - apply is_derive_exp. - Qed. - - Lemma is_derive_df_mult - (df1 df2:DefinedFunction UnitAnn DTfloat) (σ:df_env) v x y1 y2 - (pf_deriv1:is_deriv_df σ df1 v x y1) - (pf_deriv2:is_deriv_df σ df2 v x y2) : - forall a, - is_deriv_df σ (Times a df1 df2) v x ((y1 * eval_differentiable_at_point pf_deriv2 + eval_differentiable_at_point pf_deriv1 * y2)). - Proof. - unfold eval_differentiable_at_point. - intros. - destruct (eval_at_point_diferentiable_total pf_deriv1) as [xve1 eqqx1]; simpl. - destruct (eval_at_point_diferentiable_total pf_deriv2) as [xve2 eqqx2]; simpl. - unfold is_deriv_df; simpl - ; destruct pf_deriv1 as [base_closed1 base_deriv1] - ; destruct pf_deriv2 as [base_closed2 base_deriv2]. - split; [tauto | ]. - generalize (is_derive_mult (df_R σ df1 v) (df_R σ df2 v) x y1 y2 base_deriv1 base_deriv2) - ; intros HH. - unfold df_R in *. - rewrite eqqx1, eqqx2 in HH. - eapply is_derive_ext; [ | eapply HH]; trivial. - - intros; simpl. - unfold df_eval_at_point; simpl. - repeat match_option; unfold mult; simpl; lra. - Qed. - -Tactic Notation "DefinedFunction_scalar_cases" tactic(first) ident(c) := - first; - [ Case_aux c "Number"%string - | Case_aux c "Constant"%string - | Case_aux c "Var"%string - | Case_aux c "Plus"%string - | Case_aux c "Minus"%string - | Case_aux c "Times"%string - | Case_aux c "Divide"%string - | Case_aux c "Square"%string - | Case_aux c "Exp"%string - | Case_aux c "Log"%string - | Case_aux c "Abs"%string - | Case_aux c "Sign"%string - | Case_aux c "PSign"%string - | Case_aux c "Max"%string]. - - - - - Ltac refl_simpler := - repeat - match goal with - | [H: @eq var_type _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H) - | [H: @equiv var_type _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H) - | [H: @eq definition_function_types _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H) - | [H: @equiv definition_function_types _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H) - end. - - - Lemma df_R_total_plus σ a (l r:DefinedFunction UnitAnn DTfloat) v x : - fully_closed_over l ((v,DTfloat)::map (@projT1 _ _) σ) -> - fully_closed_over r ((v,DTfloat)::map (@projT1 _ _) σ) -> - df_R σ (Plus a l r) v x = df_R σ l v x + df_R σ r v x. - Proof. - simpl. - intros. - destruct (eval_fully_closed_total (addBinding σ v x) l); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v x) r); simpl; trivial. - unfold df_R, df_eval_at_point; simpl. - now rewrite e, e0. - Qed. - - Lemma df_R_total_minus σ a (l r:DefinedFunction UnitAnn DTfloat) v x : - fully_closed_over l ((v,DTfloat)::map (@projT1 _ _) σ) -> - fully_closed_over r ((v,DTfloat)::map (@projT1 _ _) σ) -> - df_R σ (Minus a l r) v x = df_R σ l v x - df_R σ r v x. - Proof. - simpl. - intros. - destruct (eval_fully_closed_total (addBinding σ v x) l); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v x) r); simpl; trivial. - unfold df_R, df_eval_at_point; simpl. - now rewrite e, e0. - Qed. - - Lemma df_R_total_times σ a (l r:DefinedFunction UnitAnn DTfloat) v x : - fully_closed_over l ((v,DTfloat)::map (@projT1 _ _) σ) -> - fully_closed_over r ((v,DTfloat)::map (@projT1 _ _) σ) -> - df_R σ (Times a l r) v x = df_R σ l v x * df_R σ r v x. - Proof. - simpl. - intros. - destruct (eval_fully_closed_total (addBinding σ v x) l); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v x) r); simpl; trivial. - unfold df_R, df_eval_at_point; simpl. - now rewrite e, e0. - Qed. - - Lemma df_R_total_divide σ a (l r:DefinedFunction UnitAnn DTfloat) v x : - fully_closed_over l ((v,DTfloat)::map (@projT1 _ _) σ) -> - fully_closed_over r ((v,DTfloat)::map (@projT1 _ _) σ) -> - df_R σ (Divide a l r) v x = df_R σ l v x / df_R σ r v x. - Proof. - simpl. - intros. - destruct (eval_fully_closed_total (addBinding σ v x) l); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v x) r); simpl; trivial. - unfold df_R, df_eval_at_point; simpl. - now rewrite e, e0. - Qed. - - Lemma floatish_sign : - forall (x:R), sign x = FloatishOps.sign x. - Proof. - intros. - unfold sign, FloatishOps.sign; simpl. - case_eq (Rlt_dec x 0); intros; trivial. - - match_case; intros. - destruct s; lra. - - match_case; intros. - + destruct s; case_eq (Rgt_dec x 0); simpl; intros; lra. - + lra. - Qed. - - Lemma pos_sign_psign: - forall (x:R), psign x = pos_sign x. - Proof. - unfold psign, pos_sign. - intros. - simpl. - now destruct (Rge_dec x 0). - Qed. - - Lemma Rmax_Fmax : - forall (x y:R), Rmax x y = Fmax x y. - unfold Rmax, Fmax; simpl; intros. - match_case; intros. - - case_eq (Rlt_dec x y); intros; trivial. - lra. - - case_eq (Rlt_dec x y); intros; trivial. - lra. - Qed. - - - Theorem df_eval_deriv_exact_correct σ (df:DefinedFunction UnitAnn DTfloat) v (x:R) y - : is_scalar_function df -> - fully_closed_over df ((v,DTfloat)::map (@projT1 _ _) σ) -> - df_eval_deriv_exact (addBinding σ v x) df (v,DTfloat) = Some y -> - is_derive (df_R σ df v) x y. - Proof. - simpl. - intros is_scalar. - generalize is_scalar. - revert y. - pattern df. - revert df is_scalar. - DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl; intros. - - Case "Number"%string. - unfold df_R, df_eval_at_point; simpl. - inversion H0; subst. - now apply (@is_derive_const R_AbsRing). - - Case "Constant"%string. - unfold df_R, df_eval_at_point; simpl. - inversion H0; subst. - now apply (@is_derive_const R_AbsRing). - - Case "Var"%string. - unfold df_R, df_eval_at_point; simpl. - inversion H0; subst. - simpl. - unfold equiv_dec, vart_eqdec. - destruct (vart_dec (sv, DTfloat) (v, DTfloat)). - + refl_simpler; simpl. - now apply (@is_derive_id R_AbsRing). - + now apply (@is_derive_const R_AbsRing). - - Case "Plus"%string. - destruct H1. - do 2 match_option_in H2. - invcs H2. - destruct is_scalar as [isc1 isc2]. - specialize (H _ isc1 H1 eqq). - specialize (H0 _ isc2 H3 eqq0). - eapply is_derive_ext. - + intros. - symmetry. - now apply df_R_total_plus. - + generalize (@is_derive_plus R_AbsRing); simpl - ; intros HH; now apply HH. - - Case "Minus"%string. - destruct H1. - do 2 match_option_in H2. - invcs H2. - destruct is_scalar as [isc1 isc2]. - specialize (H _ isc1 H1 eqq). - specialize (H0 _ isc2 H3 eqq0). - eapply is_derive_ext. - + intros. - symmetry. - now apply df_R_total_minus. - + generalize (@is_derive_minus R_AbsRing); simpl - ; intros HH; now apply HH. - - Case "Times"%string. - destruct H1. - do 4 match_option_in H2. - invcs H2. - destruct is_scalar as [isc1 isc2]. - specialize (H _ isc1 H1 eqq0). - specialize (H0 _ isc2 H3 eqq2). - eapply is_derive_ext. - + intros. - symmetry. - now apply df_R_total_times. - + generalize Derive.is_derive_mult - ; unfold plus, mult; simpl; intros HH. - replace (d * d2 + d0 * d1)%R with - (d0 * (df_R σ r v x) + (df_R σ l v x) * d2). - * apply HH; trivial. - * unfold df_R, df_eval_at_point; simpl. - rewrite eqq; rewrite eqq1. - lra. - - Case "Divide"%string. - destruct H1. - do 4 match_option_in H2. - invcs H2. - destruct is_scalar as [isc1 isc2]. - specialize (H _ isc1 H1 eqq0). - specialize (H0 _ isc2 H3 eqq2). - eapply is_derive_ext. - + intros. - symmetry. - now apply df_R_total_divide. - + generalize is_derive_div; simpl; intros HH. - replace (d0 / d1 - d * d2 / (d1 * d1)) with - ((d0 * (df_R σ r v x) - (df_R σ l v x) * d2) / ((df_R σ r v x)*((df_R σ r v x)*1))). - * destruct (Req_EM_T d1 0); [congruence |]. - inversion H5. - specialize (HH (df_R σ l v) (df_R σ r v) x d0 d2). - replace (d0 / d1 - d * d2 / (d1 * d1))%R with - ((d0 * df_R σ r v x - df_R σ l v x * d2) / (df_R σ r v x * (df_R σ r v x * 1))). - apply HH; trivial. - unfold df_R, df_eval_at_point; simpl. - now rewrite eqq1. - unfold df_R, df_eval_at_point; simpl. - rewrite eqq; rewrite eqq1. - now field. - * unfold df_R, df_eval_at_point; simpl. - rewrite eqq; rewrite eqq1. - destruct (Req_EM_T d1 0); [congruence |]. - now field; trivial. - - Case "Square"%string. - do 2 match_option_in H1. - invcs H1. - specialize (H _ is_scalar H0 eqq0). - assert (forall t, (Rsqr (df_R σ e v t) = df_R σ (Square a e) v t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial. - rewrite e0; unfold Rsqr; trivial. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H1. - * replace (2 * d * d0)%R with (d0 * (2 * d))%R by lra. - apply (@is_derive_comp R_AbsRing); trivial. - replace (d) with (df_R σ e v x). - -- apply is_derive_sqr. - -- unfold df_R, df_eval_at_point; simpl. - now rewrite eqq. - - Case "Exp"%string. - do 2 match_option_in H1. - invcs H1. - specialize (H _ is_scalar H0 eqq0). - assert (forall t, (exp (df_R σ e v t) = df_R σ (Exp a e) v t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial. - rewrite e0; trivial. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H1. - * apply (@is_derive_comp R_AbsRing); trivial. - replace (d) with (df_R σ e v x). - -- apply is_derive_exp. - -- unfold df_R, df_eval_at_point; simpl. - now rewrite eqq. - - Case "Log"%string. - do 2 match_option_in H1. - invcs H1. - specialize (H _ is_scalar H0 eqq0). - assert (forall t, (ln (df_R σ e v t) = df_R σ (Log a e) v t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial. - rewrite e0; trivial. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H1. - * destruct (Rgt_dec d 0); [|congruence]. - inversion H3. - apply (@is_derive_comp R_AbsRing); trivial. - replace (d) with (df_R σ e v x). - -- apply is_derive_ln. - unfold df_R, df_eval_at_point; simpl. - rewrite eqq. - lra. - -- unfold df_R, df_eval_at_point; simpl. - now rewrite eqq. - - Case "Abs"%string. - do 2 match_option_in H1. - invcs H1. - specialize (H _ is_scalar H0 eqq0). - assert (forall t, (Rabs (df_R σ e v t) = df_R σ (Abs a e) v t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial. - rewrite e0; trivial. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H1. - * destruct ( Req_EM_T d 0 ). - -- destruct (Req_EM_T d0 0). - ++inversion H3. - apply is_derive_Rabs_df0. - rewrite e1 in H. - apply H. - ++ congruence. - -- inversion H3. - apply (@is_derive_comp R_AbsRing); trivial. - replace (d) with (df_R σ e v x) - ; unfold df_R, df_eval_at_point; simpl; rewrite eqq; trivial. - apply is_derive_abs; lra. - - Case "Sign"%string. - match_option_in H1. - match_option_in H1. - destruct (Req_EM_T d 0); [congruence|]. - invcs H1. - specialize (H _ is_scalar H0 eqq0). - assert (forall t, (sign (df_R σ e v t) = df_R σ (Sign a e) v t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial. - rewrite e0; trivial. - apply floatish_sign. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H1. - * replace (0)%R with (d0 * 0)%R by lra. - apply (@is_derive_comp R_AbsRing); trivial. - apply is_derive_sign. - unfold df_R, df_eval_at_point; simpl. - rewrite eqq; lra. - - Case "PSign"%string. - match_option_in H1. - match_option_in H1. - destruct (Req_EM_T d 0); [congruence|]. - invcs H1. - specialize (H _ is_scalar H0 eqq0). - assert (forall t, (psign (df_R σ e v t) = df_R σ (PSign a e) v t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial. - rewrite e0; trivial. - apply pos_sign_psign. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H1. - * replace (0)%R with (d0 * 0)%R by lra. - apply (@is_derive_comp R_AbsRing); trivial. - apply is_derive_psign. - unfold df_R, df_eval_at_point; simpl. - rewrite eqq; lra. - - Case "Max"%string. - do 2 match_option_in H2. - destruct H1. - destruct is_scalar as [isc1 isc2]. - assert (forall t, (Rmax (df_R σ l v t) (df_R σ r v t)) = df_R σ (Max a l r) v t). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v t) l); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v t) r); simpl; trivial. - rewrite e, e0; trivial. - apply Rmax_Fmax. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H4. - * match_option_in H2. - match_option_in H2. - destruct (Req_EM_T d d0). - -- destruct (Req_EM_T d1 d2). - ++ invcs H2. - apply is_derive_max_alt2. - apply H; trivial. - apply H0; trivial. - ++ congruence. - -- destruct (Rlt_dec d d0). - ++ invcs H2. - specialize (H _ isc1 H1 eqq1). - specialize (H0 _ isc2 H3 eqq2). - replace (y) with ((d1 + y + (d1-y)*sign(df_R σ l v x - df_R σ r v x))/2). - ** apply is_derive_max; trivial. - unfold df_R, df_eval_at_point; simpl. - now rewrite eqq, eqq0. - ** unfold df_R, df_eval_at_point; simpl. - rewrite eqq, eqq0. - unfold sign; simpl. - destruct (total_order_T 0 (d - d0)); [destruct s; lra|lra]. - ++ invcs H2. - specialize (H _ isc1 H1 eqq1). - specialize (H0 _ isc2 H3 eqq2). - replace (y) with ((y + d2 + (y - d2)*sign(df_R σ l v x - df_R σ r v x))/2). - ** apply is_derive_max; trivial. - unfold df_R, df_eval_at_point; simpl. - now rewrite eqq, eqq0. - ** unfold df_R, df_eval_at_point; simpl. - rewrite eqq, eqq0. - unfold sign; simpl. - destruct (total_order_T 0 (d - d0)); [destruct s; lra | lra]. - Qed. - - Definition df_R_vec {Ann} {n} (σ:df_env) (df:DefinedFunction Ann (DTVector n)) v : - R -> Vector R n - := fun x => match df_eval_at_point σ df v x with - | Some y => y - | None => ConstVector n 0 - end. - - Definition df_R_mat {Ann} {n m} (σ:df_env) (df:DefinedFunction Ann (DTMatrix n m)) v : - R -> Matrix R n m - := fun x => match df_eval_at_point σ df v x with - | Some y => y - | None => ConstMatrix n m 0 - end. - - Definition df_R_gen Ann T σ : - (DefinedFunction Ann T) -> SubVar -> R -> (definition_function_types_interp T) := - match T with - | DTfloat => fun df v => df_R σ df v - | DTVector n => fun df v => df_R_vec σ df v - | DTMatrix n m => fun df v => df_R_mat σ df v - end. - - Definition is_derive_vec {n} (f : R -> Vector R n) (x:R) (df : Vector R n) := - forall i, is_derive (fun x0 => f x0 i) x (df i). - - Definition is_derive_mat {n m} (f : R -> Matrix R n m) (x:R) (df : Matrix R n m) := - forall i j, is_derive (fun x0 => f x0 i j) x (df i j). - - Definition is_derive_gen {T} (f: R->definition_function_types_interp T) (x:R) - (df : definition_function_types_interp T) - := - (match T return (R -> definition_function_types_interp T) -> - definition_function_types_interp T -> Prop - with - | DTfloat => fun f df => is_derive f x df - | DTVector n => fun f df => is_derive_vec f x df - | DTMatrix n m => fun f df => is_derive_mat f x df - end) f df. - - Definition vec_to_nat_fun {n:nat} (v:Vector R n) (i:nat) : R := - match lt_dec i n with - | left pf => v (exist _ i pf) - | right _ => 0 - end. - - Lemma vec_to_nat_fun_vcons_end {n} (v : Vector R n) b : - vec_to_nat_fun (vcons b v) n = b. - Proof. - unfold vec_to_nat_fun; simpl. - destruct (lt_dec n (S n)); [ | lia]. - destruct (Nat.eq_dec n n); [ | lia]. - trivial. - Qed. - - Lemma vec_to_nat_fun_vcons_nend {n} (v : Vector R n) b m (pf:(m sum_n (vec_to_nat_fun v) (n-1) = r)). - - unfold vec_to_nat_fun, sum_n, sum_n_m, Iter.iter, Iter.iter_nat, plus, zero. - simpl. - lra. - - unfold vec_to_nat_fun, sum_n, sum_n_m, Iter.iter, Iter.iter_nat, plus, zero, Datatypes.id. - simpl; intros. - lra. - - intros. - simpl. - rewrite <- H. - destruct n0. - + unfold vec_to_nat_fun, sum_n, sum_n_m, Iter.iter, Iter.iter_nat, plus, zero, Datatypes.id. - simpl. - lra. - + simpl. - rewrite Nat.sub_0_r. - rewrite sum_Sn. - unfold plus; simpl. - rewrite vec_to_nat_fun_vcons_end. - rewrite Rplus_comm. - f_equal; simpl. - erewrite sum_n_ext_loc; [ reflexivity | ]; intros. - apply vec_to_nat_fun_vcons_nend. - lia. - Qed. - - Lemma is_derive_vsum {n} (vf : R -> Vector R n) (x:R) (df : Vector R n) : - is_derive_vec vf x df -> - is_derive (fun x0 => vsum (vf x0)) x (vsum df). - Proof. - unfold is_derive_vec; intro. - apply (is_derive_ext (fun x0 => sum_n (vec_to_nat_fun (vf x0)) (n-1)%nat)) - ; [intros; apply sum_n_vsum |]. - replace (@vsum floatish_R n df) with (sum_n (vec_to_nat_fun df) (n-1)%nat) - ; [|intros; apply sum_n_vsum ]. - apply (@is_derive_sum_n R_AbsRing); intros. - unfold vec_to_nat_fun. - destruct (lt_dec k n). - - apply H. - - apply (@is_derive_const R_AbsRing). - Qed. - - Lemma is_derive_msum {n m} (mf : R -> Matrix R n m) (x:R) (df : Matrix R n m) : - is_derive_mat mf x df -> - is_derive (fun x0 => msum (mf x0)) x (msum df). - Proof. - simpl. - unfold is_derive_mat; intro. - unfold msum. - apply is_derive_vsum. - unfold is_derive_vec; simpl; intro. - rewrite vmap_nth. - apply (is_derive_ext (fun x0 => (@vsum floatish_R m) (mf x0 i))). - - intro; now rewrite vmap_nth. - - apply is_derive_vsum. - now unfold is_derive_vec. - Qed. - - Theorem df_eval_deriv_exact_gen_correct {T} σ (df:DefinedFunction UnitAnn T) v (x:R) y - : fully_closed_over df ((v,DTfloat)::map (@projT1 _ _) σ) -> - df_eval_deriv_exact (addBinding σ v x) df (v,DTfloat) = Some y -> - is_derive_gen (df_R_gen UnitAnn T σ df v) x y. - Proof. - revert σ v x. - DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case - ; simpl; intros σ v0 xx; intros. - - Case "Number"%string. - unfold df_R, df_eval_at_point; simpl. - inversion H0; subst. - now apply (@is_derive_const R_AbsRing). - - Case "Constant"%string. - unfold df_R_gen, is_derive_gen; simpl. - destruct t. - + unfold df_R; simpl; invcs H0. - now apply (@is_derive_const R_AbsRing). - + unfold df_R_vec; simpl; invcs H0. - unfold is_derive_vec, ConstVector; intro. - now apply (@is_derive_const R_AbsRing). - + unfold df_R_mat; simpl; invcs H0. - unfold is_derive_mat, ConstMatrix; intros. - now apply (@is_derive_const R_AbsRing). - - Case "DVector"%string. - unfold is_derive_vec; intros. - specialize (vectoro_to_ovector_forall_some_f H1); intros. - specialize (H2 i); simpl in H2. - rewrite vforall_forall in H0. - assert (H0c := H0). - specialize (H0 i). - specialize (H i (y i) σ v0 xx H0 H2); simpl in *. - apply (is_derive_ext (df_R σ (x i) v0) ); trivial; intros. - unfold df_R_vec, df_R, df_eval_at_point; simpl. - match_option. - + match_option. - * specialize (vectoro_to_ovector_forall_some_f eqq0); intros. - specialize (H3 i); simpl in H3. - now rewrite eqq in H3; invcs H3. - * unfold ConstVector. - apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0. - specialize (H0c x0). - apply (eval_fully_closed_not_none (addBinding σ v0 t) (x x0)) in H0c; tauto. - + match_option. - specialize (vectoro_to_ovector_forall_some_f eqq0); intros. - specialize (H3 i); simpl in H3; congruence. - - Case "DMatrix"%string. - unfold is_derive_mat; intros. - unfold matrixo_to_omatrix in H1. - specialize (vectoro_to_ovector_forall_some_f H1); intros. - specialize (H2 i); simpl in H2. - specialize (vectoro_to_ovector_forall_some_f H2); intros. - specialize (H3 j); simpl in H3. - rewrite vforall_forall in H0. - assert (H0c := H0). - specialize (H0 i). - rewrite vforall_forall in H0. - specialize (H0 j). - specialize (H i j (y i j) σ v0 xx H0 H3); simpl in H. - apply (is_derive_ext (df_R σ (x i j) v0)); trivial; intros. - unfold df_R_mat, df_R, df_eval_at_point; simpl. - match_option. - + match_option. - * unfold matrixo_to_omatrix in eqq0. - specialize (vectoro_to_ovector_forall_some_f eqq0); intros. - specialize (H4 i); simpl in H4. - specialize (vectoro_to_ovector_forall_some_f H4); intros. - specialize (H5 j); simpl in H5. - now rewrite eqq in H5; invcs H5. - * unfold ConstMatrix. - unfold matrixo_to_omatrix in eqq0. - apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0. - apply vectoro_to_ovector_exists_None in e; destruct e. - specialize (H0c x0). - rewrite vforall_forall in H0c. - specialize (H0c x1). - apply (eval_fully_closed_not_none (addBinding σ v0 t) (x x0 x1)) in H0c; tauto. - + match_option. - unfold matrixo_to_omatrix in eqq0. - specialize (vectoro_to_ovector_forall_some_f eqq0); intros. - specialize (H4 i); simpl in H4. - specialize (vectoro_to_ovector_forall_some_f H4); intros. - specialize (H5 j); simpl in H5; congruence. - - Case "Var"%string. - unfold is_derive_gen. - destruct v; unfold snd in *. - destruct d. - + unfold df_R_gen; simpl. - invcs H0. - destruct (vart_dec (s, DTfloat) (v0, DTfloat)). - * unfold equiv_dec, vart_eqdec. - inversion e. - destruct (vart_dec (v0, DTfloat) (v0, DTfloat)); [|congruence]. - apply (is_derive_ext id); [|apply (@is_derive_id R_AbsRing)]. - intro. - unfold id, df_R, df_eval_at_point. - simpl. - unfold equiv_dec, vart_eqdec. - destruct (vart_dec (v0, DTfloat) (v0, DTfloat)); [|congruence]. - now refl_simpler. - * unfold equiv_dec, vart_eqdec. - destruct (vart_dec (s, DTfloat) (v0, DTfloat)); [congruence|]. - invcs H; [congruence | ]. - unfold df_R, df_eval_at_point. - simpl. - unfold equiv_dec, vart_eqdec. - destruct (vart_dec (s, DTfloat) (v0, DTfloat)); [congruence|]. - apply (@is_derive_const R_AbsRing). - + simpl. - invcs H0. - unfold is_derive_vec; simpl; intros. - unfold ConstVector. - unfold df_R_vec, df_eval_at_point; simpl. - unfold equiv_dec, vart_eqdec. - destruct (vart_dec (s, DTVector n) (v0, DTfloat)); [congruence |]. - apply (@is_derive_const R_AbsRing). - + simpl. - invcs H0. - unfold is_derive_mat; simpl; intros. - unfold ConstMatrix. - unfold df_R_mat, df_eval_at_point; simpl. - unfold equiv_dec, vart_eqdec. - destruct (vart_dec (s, DTMatrix m n) (v0, DTfloat)); [congruence |]. - apply (@is_derive_const R_AbsRing). - - Case "Plus"%string. - destruct H. - do 2 match_option_in H0. - invcs H0. - specialize (IHdf1 _ σ v0 xx H eqq). - specialize (IHdf2 _ σ v0 xx H1 eqq0). - eapply is_derive_ext. - + intros. - symmetry. - now apply df_R_total_plus. - + generalize (@is_derive_plus R_AbsRing); simpl - ; intros HH; now apply HH. - - Case "Minus"%string. - destruct H. - do 2 match_option_in H0. - invcs H0. - specialize (IHdf1 _ σ v0 xx H eqq). - specialize (IHdf2 _ σ v0 xx H1 eqq0). - eapply is_derive_ext. - + intros. - symmetry. - now apply df_R_total_minus. - + generalize (@is_derive_minus R_AbsRing); simpl - ; intros HH; now apply HH. - - Case "Times"%string. - destruct H. - do 4 match_option_in H0. - invcs H0. - specialize (IHdf1 _ σ v0 xx H eqq0). - specialize (IHdf2 _ σ v0 xx H1 eqq2). - eapply is_derive_ext. - + intros. - symmetry. - now apply df_R_total_times. - + generalize Derive.is_derive_mult - ; unfold plus, mult; simpl; intros HH. - replace (d * d2 + d0 * d1)%R with - (d0 * (df_R σ df2 v0 xx) + (df_R σ df1 v0 xx) * d2). - * apply HH; trivial. - * unfold df_R, df_eval_at_point; simpl. - rewrite eqq; rewrite eqq1. - lra. - - Case "Divide"%string. - destruct H. - do 4 match_option_in H0. - destruct (Req_EM_T d1 0); [congruence |]. - invcs H0. - specialize (IHdf1 _ σ v0 xx H eqq0). - specialize (IHdf2 _ σ v0 xx H1 eqq2). - eapply is_derive_ext. - + intros. - symmetry. - now apply df_R_total_divide. - + generalize is_derive_div; simpl; intros HH. - replace (d0 / d1 - d * d2 / (d1 * d1))%R with - ((d0 * (df_R σ df2 v0 xx) - (df_R σ df1 v0 xx) * d2) / ((df_R σ df2 v0 xx)*((df_R σ df2 v0 xx)*1))). - * specialize (HH (df_R σ df1 v0) (df_R σ df2 v0) xx d0 d2). - apply HH; trivial. - unfold df_R, df_eval_at_point; simpl. - now rewrite eqq1. - * unfold df_R, df_eval_at_point; simpl. - rewrite eqq; rewrite eqq1. - destruct (Req_EM_T d1 0); [congruence |]. - now field; trivial. - - Case "Square"%string. - do 2 match_option_in H0. - invcs H0. - specialize (IHdf _ σ v0 xx H eqq0). - assert (forall t, (Rsqr (df_R σ df v0 t) = df_R σ (Square ann df) v0 t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial. - rewrite e; unfold Rsqr; trivial. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H0. - * replace (2 * d * d0)%R with (d0 * (2 * d))%R by lra. - apply (@is_derive_comp R_AbsRing); trivial. - replace (d) with (df_R σ df v0 xx). - -- apply is_derive_sqr. - -- unfold df_R, df_eval_at_point; simpl. - now rewrite eqq. - - Case "Exp"%string. - do 2 match_option_in H0. - invcs H0. - specialize (IHdf _ σ v0 xx H eqq0). - assert (forall t, (exp (df_R σ df v0 t) = df_R σ (Exp ann df) v0 t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial. - rewrite e; trivial. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H0. - * apply (@is_derive_comp R_AbsRing); trivial. - replace (d) with (df_R σ df v0 xx). - -- apply is_derive_exp. - -- unfold df_R, df_eval_at_point; simpl. - now rewrite eqq. - - Case "Log"%string. - do 2 match_option_in H0. - invcs H0. - specialize (IHdf _ σ v0 xx H eqq0). - assert (forall t, (ln (df_R σ df v0 t) = df_R σ (Log ann df) v0 t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial. - rewrite e; trivial. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H0. - * destruct (Rgt_dec d 0); [|congruence]. - inversion H2. - apply (@is_derive_comp R_AbsRing); trivial. - replace (d) with (df_R σ df v0 xx). - -- apply is_derive_ln. - unfold df_R, df_eval_at_point; simpl. - rewrite eqq. - lra. - -- unfold df_R, df_eval_at_point; simpl. - now rewrite eqq. - - Case "Abs"%string. - do 2 match_option_in H0. - invcs H0. - specialize (IHdf _ σ v0 xx H eqq0). - assert (forall t, (Rabs (df_R σ df v0 t) = df_R σ (Abs ann df) v0 t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial. - rewrite e; trivial. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H0. - * destruct ( Req_EM_T d 0 ). - -- destruct ( Req_EM_T d0 0 ). - ++ invcs H2. - apply is_derive_Rabs_df0. - apply IHdf. - ++ congruence. - -- inversion H2. - apply (@is_derive_comp R_AbsRing); trivial. - replace (d) with (df_R σ df v0 xx). - ++ replace (@FloatishOps.sign floatish_R (df_R σ df v0 xx)) with (sign (df_R σ df v0 xx)). - ** apply is_derive_abs. - unfold df_R, df_eval_at_point; simpl. - rewrite eqq. - lra. - ** apply floatish_sign. - ++ unfold df_R, df_eval_at_point; simpl. - now rewrite eqq. - - Case "Sign"%string. - do 2 match_option_in H0. - destruct (Req_EM_T d 0); [congruence|]. - invcs H0. - specialize (IHdf _ σ v0 xx H eqq0). - assert (forall t, (sign (df_R σ df v0 t) = df_R σ (Sign ann df) v0 t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial. - rewrite e; trivial. - apply floatish_sign. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H0. - * replace (0)%R with (d0 * 0)%R by lra. - apply (@is_derive_comp R_AbsRing); trivial. - apply is_derive_sign. - unfold df_R, df_eval_at_point; simpl. - rewrite eqq; lra. - - Case "PSign"%string. - do 2 match_option_in H0. - destruct (Req_EM_T d 0); [congruence|]. - invcs H0. - specialize (IHdf _ σ v0 xx H eqq0). - assert (forall t, (psign (df_R σ df v0 t) = df_R σ (PSign ann df) v0 t)). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial. - rewrite e; trivial. - apply pos_sign_psign. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H0. - * replace (0)%R with (d0 * 0)%R by lra. - apply (@is_derive_comp R_AbsRing); trivial. - apply is_derive_psign. - unfold df_R, df_eval_at_point; simpl. - rewrite eqq; lra. - - Case "Max"%string. - do 4 match_option_in H0. - destruct H. - assert (forall t, (Rmax (df_R σ df1 v0 t) (df_R σ df2 v0 t)) = df_R σ (Max ann df1 df2) v0 t). - + intros; simpl. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - rewrite e, e0; trivial. - apply Rmax_Fmax. - + eapply is_derive_ext. - * intros; simpl. - now rewrite H2. - * destruct (Req_EM_T d d0). - -- destruct (Req_EM_T d1 d2). - ++ invcs H0. - apply is_derive_max_alt2. - apply IHdf1; trivial. - apply IHdf2; trivial. - ++ congruence. - -- destruct (Rlt_dec d d0). - ++ invcs H0. - specialize (IHdf1 _ σ v0 xx H eqq1). - specialize (IHdf2 _ σ v0 xx H1 eqq2). - replace (y) with ((d1 + y + (d1-y)*sign(df_R σ df1 v0 xx - df_R σ df2 v0 xx))/2). - ** apply is_derive_max; trivial. - unfold df_R, df_eval_at_point; simpl. - now rewrite eqq, eqq0. - ** unfold df_R, df_eval_at_point; simpl. - rewrite eqq, eqq0. - unfold sign; simpl. - destruct (total_order_T 0 (d - d0)); [destruct s; lra | lra]. - ++ invcs H0. - specialize (IHdf1 _ σ v0 xx H eqq1). - specialize (IHdf2 _ σ v0 xx H1 eqq2). - replace (y) with ((y + d2 + (y - d2)*sign(df_R σ df1 v0 xx - df_R σ df2 v0 xx))/2). - ** apply is_derive_max; trivial. - unfold df_R, df_eval_at_point; simpl. - now rewrite eqq, eqq0. - ** unfold df_R, df_eval_at_point; simpl. - rewrite eqq, eqq0. - unfold sign; simpl. - destruct (total_order_T 0 (d - d0)); [destruct s; lra | lra]. - - Case "VectorDot"%string. - do 4 match_option_in H0. - invcs H0. - destruct H. - specialize (IHdf1 _ σ v0 xx H eqq0). - specialize (IHdf2 _ σ v0 xx H0 eqq2). - apply (is_derive_ext (fun x0 => (@vsum floatish_R n) - (fun i => ((df_R_vec σ df1 v0 x0 i) * - (df_R_vec σ df2 v0 x0 i))%R ))). - + intros. - unfold df_R_vec, df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - now rewrite e, e0. - + apply is_derive_vsum. - unfold df_R_vec, df_R, df_eval_at_point; simpl. - unfold is_derive_vec; intros. - generalize is_derive_mult - ; unfold plus, mult; simpl; intros HH. - replace (d i * d2 i + d0 i * d1 i)%R with - (d0 i * (df_R_vec σ df2 v0 xx i) + (df_R_vec σ df1 v0 xx i) * d2 i). - * apply HH; trivial. - * unfold df_R_vec, df_R, df_eval_at_point; simpl. - rewrite eqq; rewrite eqq1. - lra. - - Case "VectorSum"%string. - match_option_in H0. - invcs H0. - specialize (IHdf _ σ v0 xx H eqq). - assert (forall t, (vsum (df_R_vec σ df v0 t)) = df_R σ (VectorSum ann df) v0 t). - + intros; simpl. - unfold df_R_vec, df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial. - rewrite e; unfold Rsqr; trivial. - + apply (is_derive_ext (fun x0 => (@vsum floatish_R n) (df_R_vec σ df v0 x0))). - * intros; simpl. - now rewrite H0. - * apply is_derive_vsum. - now simpl in IHdf. - - Case "MatrixSum"%string. - match_option_in H0. - invcs H0. - specialize (IHdf _ σ v0 xx H eqq). - assert (forall t, (msum (df_R_mat σ df v0 t)) = df_R σ (MatrixSum ann df) v0 t). - + intros; simpl. - unfold df_R_mat, df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial. - rewrite e; unfold Rsqr; trivial. - + apply (is_derive_ext (fun x0 => (@msum floatish_R m n) (df_R_mat σ df v0 x0))). - * intros; simpl. - now rewrite H0. - * apply is_derive_msum. - now simpl in IHdf. - - Case "VectorElem"%string. - match_option_in H0. - specialize (IHdf d σ v0 xx H eqq); simpl in IHdf. - invcs H0. - apply (is_derive_ext (fun x0 => (df_R_vec σ df v0) x0 i)); intros. - + unfold df_R_vec, df_R, df_eval_at_point. simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial. - now rewrite e. - + unfold is_derive_vec in IHdf. - apply IHdf. - - Case "MatrixElem"%string. - match_option_in H0. - specialize (IHdf d σ v0 xx H eqq); simpl in IHdf. - invcs H0. - apply (is_derive_ext (fun x0 => (df_R_mat σ df v0) x0 i j)); intros. - + unfold df_R_mat, df_R, df_eval_at_point. simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial. - now rewrite e. - + unfold is_derive_mat in IHdf. - apply IHdf. - - Case "MatrixVectorMult"%string. - do 4 match_option_in H0. - invcs H0. - destruct H. - specialize (IHdf1 _ σ v0 xx H eqq0). - specialize (IHdf2 _ σ v0 xx H0 eqq2). - unfold is_derive_vec; intros. - apply (is_derive_ext - (fun x0 => ((@matrix_vector_mult floatish_R m n) - (df_R_mat σ df1 v0 x0) - (df_R_vec σ df2 v0 x0)) i)). - + intros. - unfold df_R_vec, df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - rewrite e, e0. - unfold df_R_mat, df_R, df_eval_at_point; simpl. - now rewrite e. - + unfold matrix_vector_mult. - apply is_derive_vsum. - unfold is_derive_vec; intros. - simpl. - unfold df_R_vec, df_R, df_eval_at_point; simpl. - unfold is_derive_mat; intros. - generalize Derive.is_derive_mult - ; unfold plus, mult; simpl; intros HH. - replace (d i i0 * d2 i0 + d0 i i0 * d1 i0)%R with - (d0 i i0 * (df_R_vec σ df2 v0 xx i0) + (df_R_mat σ df1 v0 xx i i0) * d2 i0)%R. - * apply HH; trivial. - simpl in IHdf1. - unfold is_derive_mat in IHdf1. - apply IHdf1. - * unfold df_R_mat, df_R_vec, df_R, df_eval_at_point; simpl. - rewrite eqq; rewrite eqq1. - lra. - - Case "MatrixVectorAdd"%string. - do 2 match_option_in H0. - invcs H0. - destruct H. - specialize (IHdf1 _ σ v0 xx H eqq). - specialize (IHdf2 _ σ v0 xx H0 eqq0). - unfold is_derive_mat; intros. - apply (is_derive_ext - (fun x0 => ((@matrix_vector_add floatish_R m n) - (df_R_mat σ df1 v0 x0) - (df_R_vec σ df2 v0 x0)) i j)). - + intros. - unfold df_R_mat, df_R_vec, df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - rewrite e, e0. - now unfold matrix_vector_add. - + unfold matrix_vector_add; simpl. - simpl in *. - apply (@is_derive_plus R_AbsRing); simpl. - apply IHdf1. - apply IHdf2. - - Case "MatrixMult"%string. - do 4 match_option_in H0. - invcs H0. - destruct H. - specialize (IHdf1 _ σ v0 xx H eqq0). - specialize (IHdf2 _ σ v0 xx H0 eqq2). - unfold is_derive_mat; intros. - apply (is_derive_ext - (fun x0 => ((@matrix_mult floatish_R m p n) - (df_R_mat σ df1 v0 x0) - (df_R_mat σ df2 v0 x0)) i j)). - + intros. - unfold df_R_mat, df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - now rewrite e, e0. - + unfold matrix_mult. - apply is_derive_vsum. - unfold is_derive_vec; intros. - simpl. - unfold df_R_mat, df_R, df_eval_at_point; simpl. - generalize Derive.is_derive_mult - ; unfold plus, mult; simpl; intros HH. - replace (d i i0 * d2 i0 j + d0 i i0 * d1 i0 j)%R with - (d0 i i0 * (df_R_mat σ df2 v0 xx i0 j) + (df_R_mat σ df1 v0 xx i i0) * d2 i0 j)%R. - * apply HH; trivial. - * unfold df_R_mat, df_R_vec, df_R, df_eval_at_point; simpl. - rewrite eqq; rewrite eqq1. - lra. - - Case "VectorPlus"%string. - destruct H. - do 2 match_option_in H0. - invcs H0. - specialize (IHdf1 _ σ v0 xx H eqq). - specialize (IHdf2 _ σ v0 xx H1 eqq0). - unfold is_derive_vec; intro. - simpl in *. - unfold is_derive_vec in IHdf1. - unfold is_derive_vec in IHdf2. - specialize (IHdf1 i); specialize (IHdf2 i). - apply (is_derive_ext (fun x0 => ((df_R_vec σ df1 v0 x0 i) + (df_R_vec σ df2 v0 x0 i))%R)). - + intros. - symmetry. - unfold df_R_vec, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - now rewrite e, e0. - + generalize (@is_derive_plus R_AbsRing); simpl - ; intros HH; now apply HH. - - Case "VectorMinus"%string. - destruct H. - do 2 match_option_in H0. - invcs H0. - specialize (IHdf1 _ σ v0 xx H eqq). - specialize (IHdf2 _ σ v0 xx H1 eqq0). - unfold is_derive_vec; intro. - simpl in *. - unfold is_derive_vec in IHdf1. - unfold is_derive_vec in IHdf2. - specialize (IHdf1 i); specialize (IHdf2 i). - apply (is_derive_ext (fun x0 => ((df_R_vec σ df1 v0 x0 i) - (df_R_vec σ df2 v0 x0 i))%R)). - + intros. - symmetry. - unfold df_R_vec, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - now rewrite e, e0. - + generalize (@is_derive_minus R_AbsRing); simpl - ; intros HH; now apply HH. - - Case "MatrixPlus"%string. - destruct H. - do 2 match_option_in H0. - invcs H0. - specialize (IHdf1 _ σ v0 xx H eqq). - specialize (IHdf2 _ σ v0 xx H1 eqq0). - unfold is_derive_mat; intros. - simpl in *. - unfold is_derive_mat in IHdf1. - unfold is_derive_mat in IHdf2. - specialize (IHdf1 i j); specialize (IHdf2 i j). - apply (is_derive_ext (fun x0 => ((df_R_mat σ df1 v0 x0 i j) + (df_R_mat σ df2 v0 x0 i j))%R)). - + intros. - symmetry. - unfold df_R_mat, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - now rewrite e, e0. - + generalize (@is_derive_plus R_AbsRing); simpl - ; intros HH; now apply HH. - - Case "MatrixMinus"%string. - destruct H. - do 2 match_option_in H0. - invcs H0. - specialize (IHdf1 _ σ v0 xx H eqq). - specialize (IHdf2 _ σ v0 xx H1 eqq0). - unfold is_derive_mat; intros. - simpl in *. - unfold is_derive_mat in IHdf1. - unfold is_derive_mat in IHdf2. - specialize (IHdf1 i j); specialize (IHdf2 i j). - apply (is_derive_ext (fun x0 => ((df_R_mat σ df1 v0 x0 i j) - (df_R_mat σ df2 v0 x0 i j))%R)). - + intros. - symmetry. - unfold df_R_mat, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - now rewrite e, e0. - + generalize (@is_derive_minus R_AbsRing); simpl - ; intros HH; now apply HH. - - Case "VectorScalMult"%string. - do 4 match_option_in H0. - invcs H0. - destruct H. - specialize (IHdf1 _ σ v0 xx H eqq0). - specialize (IHdf2 _ σ v0 xx H0 eqq2). - unfold is_derive_vec; intros. - apply (is_derive_ext (fun x0 => ((df_R σ df1 v0 x0) * (df_R_vec σ df2 v0 x0 i))%R)). - + intro. - unfold df_R_vec, df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - now rewrite e, e0. - + generalize Derive.is_derive_mult - ; unfold plus, mult; simpl; intros HH. - replace (d * d2 i + d0 * d1 i)%R with - (d0 * (df_R_vec σ df2 v0 xx i) + (df_R σ df1 v0 xx) * (d2 i))%R. - * apply HH; trivial. - simpl in IHdf1; simpl in IHdf2. - unfold is_derive_vec in IHdf2. - apply IHdf2. - * unfold df_R_vec, df_R, df_eval_at_point; simpl. - rewrite eqq; rewrite eqq1. - lra. - - Case "MatrixScalMult"%string. - do 4 match_option_in H0. - invcs H0. - destruct H. - specialize (IHdf1 _ σ v0 xx H eqq0). - specialize (IHdf2 _ σ v0 xx H0 eqq2). - unfold is_derive_mat; intros. - apply (is_derive_ext (fun x0 => ((df_R σ df1 v0 x0) * (df_R_mat σ df2 v0 x0 i j))%R)). - + intro. - unfold df_R_mat, df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - now rewrite e, e0. - + generalize Derive.is_derive_mult - ; unfold plus, mult; simpl; intros HH. - replace (d * d2 i j + d0 * d1 i j)%R with - (d0 * (df_R_mat σ df2 v0 xx i j) + (df_R σ df1 v0 xx) * (d2 i j))%R. - * apply HH; trivial. - simpl in IHdf1; simpl in IHdf2. - unfold is_derive_mat in IHdf2. - apply IHdf2. - * unfold df_R_mat, df_R, df_eval_at_point; simpl. - rewrite eqq; rewrite eqq1. - lra. - - Case "VectorApply"%string. - destruct H. - do 2 match_option_in H0. - specialize (vectoro_to_ovector_forall_some_f H0); intros. - specialize (IHdf2 d0 σ v0 xx H1 eqq0). - unfold is_derive_vec; intro. - specialize (H2 i); simpl in H2. - match_option_in H2; invcs H2. - simpl in IHdf2; simpl in IHdf1. - specialize (IHdf1 d1). - unfold is_derive_vec in IHdf2; specialize (IHdf2 i). - generalize - (@is_derive_comp R_AbsRing R_NormedModule - (df_R nil df1 v) - (fun x0 => (df_R_vec σ df2 v0 x0 i))); simpl; intros. - specialize (H2 xx d1 (d0 i)). - apply (is_derive_ext (fun x : R => df_R Datatypes.nil df1 v (df_R_vec σ df2 v0 x i))) - ; intros. - + destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - unfold df_R_vec, df_eval_at_point; simpl. - rewrite e. - unfold df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (mk_env_entry (v, DTfloat) (x i) :: nil) df1) - ; simpl; trivial. - unfold addBinding. - rewrite e0. - match_option. - * specialize (vectoro_to_ovector_forall_some_f eqq2); intros. - specialize (H3 i);simpl in H3. - congruence. - * apply vectoro_to_ovector_exists_None in eqq2. - destruct eqq2. - destruct (eval_fully_closed_total (mk_env_entry (v, DTfloat) (x x1) :: nil) df1) - ; simpl; trivial. - congruence. - + apply H2. - * unfold df_R_vec, df_R, df_eval_at_point; simpl. - rewrite eqq. - unfold df_R, df_eval_at_point in IHdf1. - specialize (IHdf1 nil v (d i)). - apply IHdf1; trivial. - * apply IHdf2. - - Case "MatrixApply"%string. - destruct H. - do 2 match_option_in H0. - specialize (vectoro_to_ovector_forall_some_f H0); intros. - specialize (IHdf2 d0 σ v0 xx H1 eqq0). - unfold is_derive_mat; intros. - specialize (H2 i); simpl in H2. - unfold matrixo_to_omatrix in H0. - specialize (vectoro_to_ovector_forall_some_f H2); intros. - specialize (H3 j); simpl in H3. - match_option_in H3; invcs H3. - simpl in IHdf2; simpl in IHdf1. - specialize (IHdf1 d1). - unfold is_derive_mat in IHdf2; specialize (IHdf2 i j). - generalize - (@is_derive_comp R_AbsRing R_NormedModule - (df_R nil df1 v) - (fun x0 => (df_R_mat σ df2 v0 x0 i j))) - ; simpl; intros. - specialize (H3 xx d1 (d0 i j)). - apply (is_derive_ext (fun x : R => df_R Datatypes.nil df1 v (df_R_mat σ df2 v0 x i j))) - ; intros. - + destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - unfold df_R_mat, df_eval_at_point; simpl. - rewrite e. - unfold matrixo_to_omatrix. - destruct (eval_fully_closed_total (mk_env_entry (v, DTfloat) (x i j) :: nil) df1) - ; simpl; trivial. - unfold df_R, df_eval_at_point, addBinding; simpl. - rewrite e0. - match_option. - * specialize (vectoro_to_ovector_forall_some_f eqq2); intros. - specialize (H4 i); simpl in H4. - specialize (vectoro_to_ovector_forall_some_f H4); intros. - specialize (H6 j); simpl in H6. - congruence. - * apply vectoro_to_ovector_exists_None in eqq2. - destruct eqq2. - apply vectoro_to_ovector_exists_None in e1. - destruct e1. - destruct (eval_fully_closed_total (mk_env_entry (v, DTfloat) (x x1 x2) :: nil) df1) - ; simpl; trivial. - congruence. - + apply H3. - * unfold df_R_mat, df_R, df_eval_at_point; simpl. - rewrite eqq. - specialize (IHdf1 nil v (d i j)). - unfold df_R, df_eval_at_point in IHdf1; simpl in IHdf1. - apply IHdf1; trivial. - * apply IHdf2. - - Case "VLossfun"%string. - destruct H. - do 3 match_option_in H0. - invcs H0. - specialize (vectoro_to_ovector_forall_some_f eqq1); simpl; intros. - specialize (IHdf2 d0 σ v0 xx H1 eqq0). - simpl in IHdf2; simpl in IHdf1. - unfold is_derive_vec in IHdf2. - unfold df_R, df_eval_at_point; simpl. - generalize (is_derive_vsum - (fun x : R => - match df_eval (addBinding σ v0 x) df2 with - | Some l' => - match - vectoro_to_ovector - (fun i : {n' : nat | (n' < n)%nat} => - df_eval - (mk_env_entry (v1, DTfloat) (l' i) - :: mk_env_entry (v2, DTfloat) (r i) - :: Datatypes.nil) df1) - with - | Some vv => vv - | None => fun _ => 0%R - end - | None => fun _ => 0%R - end) - xx v); intros. - apply (is_derive_ext - (fun x0 : R_AbsRing => - vsum - match df_eval (addBinding σ v0 x0) df2 with - | Some l' => - match - vectoro_to_ovector - (fun i : {n' : nat | (n' < n)%nat} => - df_eval - (mk_env_entry (v1, DTfloat) (l' i) - :: mk_env_entry (v2, DTfloat) (r i) - :: Datatypes.nil) df1) - with - | Some vv => vv - | None => fun _ : {n' : nat | (n' < n)%nat} => 0 - end - | None => fun _ : {n' : nat | (n' < n)%nat} => 0 - end)); intros. - + destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - rewrite e. - match_option. - apply vsum0. - + apply H2. - unfold is_derive_vec; simpl; intros. - generalize - (@is_derive_comp R_AbsRing R_NormedModule - (df_R (addBinding nil v2 (r i)) df1 v1) - (fun x0 => (df_R_vec σ df2 v0 x0 i))); simpl; intros. - apply (is_derive_ext - (fun x0 : R => - df_R (addBinding Datatypes.nil v2 (r i)) df1 v1 (df_R_vec σ df2 v0 x0 i)));intros. - * unfold df_R_vec, df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - rewrite e. - destruct (eval_fully_closed_total - (addBinding (addBinding Datatypes.nil v2 (r i)) v1 (x i)) df1); simpl; trivial. - rewrite e0. - match_option. - -- specialize (vectoro_to_ovector_forall_some_f eqq2); intros. - specialize (H4 i); simpl in H4. - unfold addBinding in e0. - congruence. - -- apply vectoro_to_ovector_exists_None in eqq2. - destruct eqq2. - destruct (eval_fully_closed_total - (mk_env_entry (v1, DTfloat) (x x1) - :: mk_env_entry (v2, DTfloat) (r x1) - :: Datatypes.nil) df1); trivial. - congruence. - * specialize (H0 i). - match_option_in H0. - inversion H0. - apply H3; trivial. - apply IHdf1; trivial. - unfold addBinding, df_R_vec, df_eval_at_point; simpl. - rewrite eqq. - apply eqq2. - - Case "MLossfun"%string. - destruct H. - do 3 match_option_in H0. - invcs H0. - destruct (Nat.eq_dec n 0); [congruence|]. - invcs H3. - specialize (vectoro_to_ovector_forall_some_f eqq1); simpl; intros. - specialize (IHdf2 d0 σ v0 xx H1 eqq0). - simpl in IHdf2; simpl in IHdf1. - unfold is_derive_mat in IHdf2. - replace ((@msum floatish_R m n m0) / IZR (Z.of_nat n))%R with (/ (IZR (Z.of_nat n))%R * msum m0)%R by lra. - apply (is_derive_ext (fun x0 => scal (Rinv (IZR (Z.of_nat n))%R) - (scal (IZR (Z.of_nat n))%R - (df_R σ (MLossfun ann v1 v2 df1 df2 r) v0 x0)))); intros. - + unfold scal; simpl. - unfold mult; simpl. - field. - apply IZR_neq; lia. - + apply is_derive_scal. - unfold df_R, df_eval_at_point; simpl. - generalize (is_derive_msum - (fun x : R => - match df_eval (addBinding σ v0 x) df2 with - | Some l' => - match - matrixo_to_omatrix - (fun (i : {n' : nat | (n' < m)%nat}) - (j : {m' : nat | (m' < n)%nat}) => - df_eval - (mk_env_entry (v1, DTfloat) (l' i j) - :: mk_env_entry (v2, DTfloat) (r i j) - :: Datatypes.nil) df1) - with - | Some vv => vv - | None => fun _ _ => 0%R - end - | None => fun _ _ => 0%R - end) - xx m0); intros. - apply (is_derive_ext - (fun x0 : R_AbsRing => - msum - match df_eval (addBinding σ v0 x0) df2 with - | Some l' => - match - matrixo_to_omatrix - (fun (i : {n' : nat | (n' < m)%nat}) - (j : {m' : nat | (m' < n)%nat}) => - df_eval - (mk_env_entry (v1, DTfloat) (l' i j) - :: mk_env_entry (v2, DTfloat) (r i j) - :: Datatypes.nil) df1) - with - | Some vv => vv - | None => - fun (_ : {n' : nat | (n' < m)%nat}) - (_ : {m' : nat | (m' < n)%nat}) => 0 - end - | None => fun (_ : {n' : nat | (n' < m)%nat}) - (_ : {m' : nat | (m' < n)%nat}) => 0 - end)); intros. - * destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - rewrite e. - unfold matrixo_to_omatrix; simpl. - match_option. - -- unfold scal; simpl. - unfold mult; simpl. - generalize (msum v); unfold float; simpl; intros. - field. - rewrite <- INR_IZR_INZ. - now apply INR_nzero_eq. - -- unfold scal; simpl. - unfold mult; simpl. - replace (IZR (Z.of_nat n) * 0)%R with (0)%R by lra. - unfold msum. - erewrite vsum_ext; try apply vsum0; intro. - rewrite vmap_nth. - erewrite vsum_ext; try apply vsum0; intro. - trivial. - * apply H2. - unfold is_derive_mat; simpl; intros. - generalize - (@is_derive_comp R_AbsRing R_NormedModule - (df_R (addBinding nil v2 (r i j)) df1 v1) - (fun x0 => (df_R_mat σ df2 v0 x0 i j))); simpl; intros. - apply (is_derive_ext - (fun x0 : R => - df_R (addBinding Datatypes.nil v2 (r i j)) df1 v1 - (df_R_mat σ df2 v0 x0 i j)));intros. - -- unfold df_R_mat, df_R, df_eval_at_point; simpl. - destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial. - rewrite e. - destruct (eval_fully_closed_total - (addBinding (addBinding Datatypes.nil v2 (r i j)) v1 (x i j)) df1) - ; simpl; trivial. - rewrite e0. - match_option. - ++ specialize (vectoro_to_ovector_forall_some_f eqq2); intros. - specialize (H4 i); simpl in H4. - specialize (vectoro_to_ovector_forall_some_f H4); intros. - specialize (H5 j); simpl in H5. - unfold addBinding in e0. - congruence. - ++ unfold matrixo_to_omatrix in eqq2. - apply vectoro_to_ovector_exists_None in eqq2. - destruct eqq2. - apply vectoro_to_ovector_exists_None in e1. - destruct e1. - destruct (eval_fully_closed_total - (mk_env_entry (v1, DTfloat) (x x1 x2) - :: mk_env_entry (v2, DTfloat) (r x1 x2) - :: Datatypes.nil) df1); trivial. - congruence. - -- specialize (H0 i). - specialize (vectoro_to_ovector_forall_some_f H0); intros. - specialize (H4 j); simpl in H4. - match_option_in H4. - inversion H4. - apply H3; trivial. - apply IHdf1; trivial. - unfold addBinding, df_R_vec, df_eval_at_point; simpl. - unfold df_R_mat, df_eval_at_point; simpl. - rewrite eqq. - apply eqq2. - Qed. diff --git a/coq/utils/ExtrFloatishIEEE.v b/coq/utils/ExtrFloatishIEEE.v deleted file mode 100644 index ad4a798a..00000000 --- a/coq/utils/ExtrFloatishIEEE.v +++ /dev/null @@ -1,33 +0,0 @@ -Require Import Extraction. -Require Import FloatishIEEE. - -(* This is assumed by fromZ *) -Require Import ExtrOcamlZInt. - -Extract Constant IEEE_float => "float". - -Extract Inlined Constant IEEE_zero => "0.". -Extract Inlined Constant IEEE_opp => "Float.neg". -Extract Inlined Constant IEEE_plus => "Float.add". -Extract Inlined Constant IEEE_minus => "Float.sub". -Extract Inlined Constant IEEE_mult => "Float.mul". -Extract Inlined Constant IEEE_div => "Float.div". -Extract Inlined Constant IEEE_sqrt => "Float.sqrt". -Extract Inlined Constant IEEE_abs => "Float.abs". - - -Extract Inlined Constant IEEE_exp => "Float.exp". -Extract Inlined Constant IEEE_ln => "Float.log". - -Extract Inlined Constant IEEE_pi => "Float.pi". -Extract Inlined Constant IEEE_sin => "Float.sin". -Extract Inlined Constant IEEE_cos => "Float.cos". - -Extract Inlined Constant IEEE_fromZ => "Float.of_int". - -Extract Inlined Constant IEEE_eq => "Float.equal". -Extract Inlined Constant IEEE_neq => "(fun x y -> x <> y)". -Extract Inlined Constant IEEE_lt => "(fun x y -> x < y)". -Extract Inlined Constant IEEE_le => "(fun x y -> x <= y)". -Extract Inlined Constant IEEE_gt => "(fun x y -> x > y)". -Extract Inlined Constant IEEE_ge => "(fun x y -> x >= y)". diff --git a/coq/utils/Floatish.v b/coq/utils/Floatish.v deleted file mode 100644 index f3ee3f12..00000000 --- a/coq/utils/Floatish.v +++ /dev/null @@ -1,8 +0,0 @@ -Require Export FloatishDef. -Require Export FloatishOps. - -Require Export FloatishInterval. -Require Export FloatishIEEE. -Require Export FloatishReal. - -Require Export FloatishRealOps. diff --git a/coq/utils/Floatish/FloatishDef.v b/coq/utils/Floatish/FloatishDef.v deleted file mode 100644 index b244525e..00000000 --- a/coq/utils/Floatish/FloatishDef.v +++ /dev/null @@ -1,52 +0,0 @@ -Require Import BinInt. - -Declare Scope float. - -Class floatish : Type := - { - float : Type - ; Fzero : float - - ; Fopp : float -> float - - ; Fplus : float -> float -> float - ; Fminus : float -> float -> float - ; Fmult : float -> float -> float - ; Fdiv : float -> float -> float - - ; Fsqrt : float -> float - ; Fabs : float -> float - - ; Fexp : float -> float - ; Fln : float -> float - - ; Fsin : float -> float - ; Fcos : float -> float - ; Fpi : float - - ; FfromZ : Z -> float - - ; Feq : float -> float -> bool - ; Fneq : float -> float -> bool - ; Flt : float -> float -> bool - ; Fle : float -> float -> bool - ; Fgt : float -> float -> bool - ; Fge : float -> float -> bool - }. - -Notation "0" := (Fzero) : float. -Notation "1" := (FfromZ 1) : float. -Notation "2" := (FfromZ 2) : float. -Notation "- x" := (Fopp x) (at level 35, right associativity) : float. -Notation "x + y" := (Fplus x y) (at level 50, left associativity) : float. -Notation "x - y" := (Fminus x y) (at level 50, left associativity) : float. -Notation "x * y" := (Fmult x y) (at level 40, left associativity) : float. -Notation "x / y" := (Fdiv x y) (at level 40, left associativity) : float. - - -Notation "x ==b y" := (Feq x y) (at level 70, no associativity) : float. -Notation "x != y" := (Fneq x y) (at level 70, no associativity) : float. -Notation "x < y" := (Flt x y) (at level 70, no associativity) : float. -Notation "x <= y" := (Fle x y) (at level 70, no associativity) : float. -Notation "x > y" := (Fgt x y) (at level 70, no associativity) : float. -Notation "x >= y" := (Fge x y) (at level 70, no associativity) : float. diff --git a/coq/utils/Floatish/FloatishIEEE.v b/coq/utils/Floatish/FloatishIEEE.v deleted file mode 100644 index 91a82a50..00000000 --- a/coq/utils/Floatish/FloatishIEEE.v +++ /dev/null @@ -1,98 +0,0 @@ -Require Import Flocq.IEEE754.BinarySingleNaN. -Require Import BinInt. - -Require Import FloatishDef. - -Require Import Flocq.IEEE754.Binary. -Require Import Flocq.IEEE754.Bits. - - -Definition IEEE_float := binary64. -Definition IEEE_zero : IEEE_float := B754_zero 53 1024 false. -Definition IEEE_opp := b64_opp. -Definition IEEE_plus := b64_plus mode_NE. -Definition IEEE_minus := b64_minus mode_NE. -Definition IEEE_mult := b64_mult mode_NE. -Definition IEEE_div := b64_div mode_NE. -Definition IEEE_sqrt := b64_sqrt mode_NE. -Definition IEEE_abs := b64_abs. - -Definition IEEE_fromZ i := binary_normalize 53 1024 (eq_refl _) (eq_refl _) mode_NE i 0 false. - -Definition IEEE_eq (x y:IEEE_float) - := (match b64_compare x y with - | Some Eq => true - | _ => false - end). - -Definition IEEE_neq (x y:IEEE_float) - := (match b64_compare x y with - | Some Eq => false - | Some _ => true - | _ => false - end). - -Definition IEEE_lt (x y:IEEE_float) - := (match b64_compare x y with - | Some Lt => true - | _ => false - end). - -Definition IEEE_le (x y:IEEE_float) - := (match b64_compare x y with - | Some Lt => true - | Some Eq => true - | _ => false - end). - - -Definition IEEE_gt (x y:IEEE_float) - := (match b64_compare x y with - | Some Gt => true - | _ => false - end). - - -Definition IEEE_ge (x y:IEEE_float) - := (match b64_compare x y with - | Some Gt => true - | Some Eq => true - | _ => false - end). - -(* following function will be defined only via extraction *) -Axiom IEEE_exp : IEEE_float -> IEEE_float. -Axiom IEEE_ln : IEEE_float -> IEEE_float. -Axiom IEEE_pi : IEEE_float. -Axiom IEEE_sin : IEEE_float -> IEEE_float. -Axiom IEEE_cos : IEEE_float -> IEEE_float. - -Local Instance floatish_IEEE : floatish := - { - float := IEEE_float - ; Fzero := IEEE_zero - ; Fopp := IEEE_opp - ; Fplus := IEEE_plus - ; Fminus := IEEE_minus - ; Fmult := IEEE_mult - ; Fdiv := IEEE_div - ; Fsqrt := IEEE_sqrt - ; Fabs := IEEE_abs - - ; Fexp := IEEE_exp - ; Fln := IEEE_ln - - ; Fpi := IEEE_pi - ; Fsin := IEEE_sin - ; Fcos := IEEE_cos - - ; FfromZ := IEEE_fromZ - - - ; Feq := IEEE_eq - ; Fneq := IEEE_neq - ; Flt := IEEE_lt - ; Fle := IEEE_le - ; Fgt := IEEE_gt - ; Fge := IEEE_ge - }. diff --git a/coq/utils/Floatish/FloatishInterval.v b/coq/utils/Floatish/FloatishInterval.v deleted file mode 100644 index 0c0516e3..00000000 --- a/coq/utils/Floatish/FloatishInterval.v +++ /dev/null @@ -1,82 +0,0 @@ -Require Import BinInt. - -Require Import Interval.Real.Xreal. - -Require Import Interval.Interval.Transcend. -Require Import Interval.Float.Specific_ops. -Require Import Interval.Float.Specific_stdz. -Require Import Interval.Float.Basic. - -Require Import FloatishDef. - -Module F := SpecificFloat StdZRadix2. -Module A := TranscendentalFloatFast F. - -Local Instance floatish_interval_gen (prec:Z) : floatish := - { - float := F.type - ; Fzero := F.zero - - ; Fopp := F.neg - - ; Fplus := F.add_slow rnd_NE prec - ; Fminus x y := F.add_slow rnd_NE prec x (F.neg y) - ; Fmult := F.mul rnd_NE prec - ; Fdiv := F.div rnd_NE prec - - ; Fsqrt := F.sqrt rnd_NE prec - ; Fabs := F.abs - - ; Fexp x := A.I.midpoint(A.exp_fast prec x) - ; Fln x := A.I.midpoint(A.ln_fast prec x) - - ; Fsin x := A.I.midpoint(A.sin_fast prec x) - ; Fcos x := A.I.midpoint(A.cos_fast prec x) - ; Fpi := F.mul rnd_NE prec (F.fromZ 4) (A.I.midpoint (A.pi4 prec)) - - ; FfromZ := F.fromZ - - ; Feq (x y:F.type) - := (match F.cmp x y with - | Xeq => true - | _ => false - end) - - ; Fneq (x y:F.type) - := (match F.cmp x y with - | Xeq => false - | _ => true - end) - - ; Flt (x y:F.type) - := (match F.cmp x y with - | Xlt => true - | _ => false - end) - - ; Fle (x y:F.type) - := (match F.cmp x y with - | Xlt => true - | Xeq => true - | _ => false - end) - - ; Fgt (x y:F.type) - := (match F.cmp x y with - | Xgt => true - | _ => false - end) - - ; Fge (x y:F.type) - := (match F.cmp x y with - | Xgt => true - | Xeq => true - | _ => false - end) - }. - - -Local Instance floatish_interval : floatish := floatish_interval_gen 53. - -Definition FZF (r:float) := F.nearbyint rnd_NE r. -Definition FZFscale (n:Z) (r:float) := FZF (Fmult (FfromZ n) r). diff --git a/coq/utils/Floatish/FloatishOps.v b/coq/utils/Floatish/FloatishOps.v deleted file mode 100644 index b01f1a39..00000000 --- a/coq/utils/Floatish/FloatishOps.v +++ /dev/null @@ -1,25 +0,0 @@ -Require Import FloatishDef. - -Section floatish_ops. - - Context {floatish_impl:floatish}. - Local Open Scope float. - - Definition pos_sign (e:float) - := if e >= 0 then 1 else Fopp 1. - - Definition neg_sign (e:float) - := if e <= 0 then Fopp 1 else 1. - - Definition sign (e:float) - := if e < 0 then Fopp 1 - else if e > 0 then 1 - else 0. - - Definition Fmax (x y:float) - := if x < y then y else x. - - Definition Fmin (x y:float) - := if x > y then y else x. - -End floatish_ops. diff --git a/coq/utils/Floatish/FloatishReal.v b/coq/utils/Floatish/FloatishReal.v deleted file mode 100644 index b7bfde15..00000000 --- a/coq/utils/Floatish/FloatishReal.v +++ /dev/null @@ -1,33 +0,0 @@ -Require Import BinInt Reals Lra. - -Require Import FloatishDef. - -Local Instance floatish_R : floatish := - { - float := R - ; Fzero := 0%R - ; Fopp := Ropp - ; Fplus := Rplus - ; Fminus := Rminus - ; Fmult := Rmult - ; Fdiv := Rdiv - ; Fsqrt := sqrt - ; Fabs := Rabs - - ; Fexp := exp - ; Fln := ln - - ; Fpi := PI - ; Fsin := sin - ; Fcos := cos - - ; FfromZ := IZR - - - ; Feq x y := if Req_EM_T x y then true else false - ; Fneq x y := if Req_EM_T x y then false else true - ; Flt x y := if Rlt_dec x y then true else false - ; Fle x y := if Rle_dec x y then true else false - ; Fgt x y := if Rgt_dec x y then true else false - ; Fge x y := if Rge_dec x y then true else false - }. diff --git a/coq/utils/Floatish/FloatishRealOps.v b/coq/utils/Floatish/FloatishRealOps.v deleted file mode 100644 index 3a20978a..00000000 --- a/coq/utils/Floatish/FloatishRealOps.v +++ /dev/null @@ -1,21 +0,0 @@ -Require Import BinInt Reals Lra. -Require Import FloatishDef FloatishReal FloatishOps. - -Section real_pfs. - - Local Existing Instance floatish_R. - Lemma Fmax_Rmax x y : Fmax x y = Rmax x y. - Proof. - vm_compute. - destruct (Rlt_dec x y); destruct (Rle_dec); lra. - Qed. - - Lemma Fmin_Rmin x y : Fmin x y = Rmin x y. - Proof. - vm_compute. - destruct (Rgt_dec x y); destruct (Rle_dec); lra. - Qed. - -End real_pfs. - -Hint Rewrite Fmax_Rmax Fmin_Rmin : Rarith. diff --git a/coq/utils/nvector.v b/coq/utils/nvector.v deleted file mode 100644 index 8e14a0d7..00000000 --- a/coq/utils/nvector.v +++ /dev/null @@ -1,282 +0,0 @@ -Require Import List. -Require Import BinInt. -Require Import Lia. -Require Import LibUtils. - -Require Import VectorDef. -Require Vector. - -Section Vector. - -Definition vector (T:Type) (n:nat) := Vector.t T n. - -Definition vnil {T} : vector T 0 := nil T. - -Definition vcons {T} (n:nat) (c:T) (v:vector T n) : vector T (S n) := - cons T c n v. - -Definition vappend {T} (n1 n2:nat) (v1:vector T n1) (v2:vector T n2) : vector T (n1 + n2) - := append v1 v2. - -Definition vmap {A B} {n} (f:A->B) (v : vector A n) : vector B n := map f v. - -Definition vhd {T} {n:nat} (v : vector T (S n)):T := hd v. - -Definition vtl {T} {n:nat} (v : vector T (S n)) : vector T n := tl v. - -Definition vlast {T} {n:nat} (v : vector T (S n)) := last v. - -Definition vnth {T} {n:nat} (v : vector T n) (i:nat | i T := - fun i => vnth v i. - -Program Definition ConstVector {T} (n:nat) (c:T) : vector T n - := of_list (repeat c n). -Next Obligation. - now rewrite repeat_length. -Qed. - -Program Definition build_vector {T} {n:nat} (v:{n':nat | n' < n}%nat -> T) : vector T n - := of_list (Vector.vector_to_list v). -Next Obligation. - apply Vector.vector_to_list_length. -Qed. - -Lemma to_list_length {T} {n:nat} (v : vector T n) : length (to_list v) = n. - induction v; simpl; trivial. - now f_equal. -Qed. - -Program Definition vcombine {T1 T2} {n:nat} (v1:vector T1 n) (v2:vector T2 n): vector (T1*T2) n := - of_list (combine (to_list v1) (to_list v2)). -Next Obligation. - rewrite combine_length. - rewrite to_list_length, to_list_length. - apply PeanoNat.Nat.min_id. -Qed. - -Definition vector_zip {T1 T2} {n:nat} (v1:vector T1 n) (v2:vector T2 n): vector (T1*T2) n := - vcombine v1 v2. - -Definition vmap2 {A B C} {n} (f:A->B->C) (v1 : vector A n) (v2 : vector B n) : vector C n - := Vector.map2 f v1 v2. - -Definition vmap4 {A B} {n} (f:A->A->A->A->B) (v1 v2 v3 v4 : vector A n) : vector B n := - vmap2 (fun '(a1,a2) '(a3,a4) => f a1 a2 a3 a4) (vcombine v1 v2) (vcombine v3 v4). - -Program Definition vectoro_to_ovector {T} {n} (v:vector (option T) n) : option (vector T n) - := match listo_to_olist (to_list v) with - | None => None - | Some l => Some (of_list l) - end. -Next Obligation. - symmetry in Heq_anonymous. - apply listo_to_olist_some in Heq_anonymous. - rewrite <- map_length with (f := Some). - rewrite <- Heq_anonymous. - now apply to_list_length. -Qed. - -Definition vforall {A} {m:nat} (P: A -> Prop) (v:vector A m) : Prop - := Vector.Forall P v. - -End Vector. - -Section Matrix. - -Definition matrix (T:Type) (n m : nat) := vector (vector T m) n. - -Definition mat_fun {T:Type} (n m : nat) (mat : matrix T n m ) : - {n':nat | n' < n}%nat -> {m':nat | m' < m}%nat -> T := - fun i => fun j => vnth (vnth mat i) j. - -Definition mmap {A B} {n m} (f:A->B) (mat : matrix A n m) : matrix B n m := - vmap (vmap f) mat. - -Definition mnth {T} {n m :nat} (v : matrix T n m) (i:nat | i vcombine a b) (vcombine mat1 mat2). - -Definition matrix_zip {T1 T2} {n m : nat} (mat1 : matrix T1 n m) (mat2 : matrix T2 n m) : matrix (T1*T2) n m := mcombine mat1 mat2. - -Definition build_matrix {T} {n m:nat} - (mat:{n':nat | n' < n}%nat -> {m':nat | m' < m}%nat -> T) : matrix T n m - := vmap build_vector (build_vector mat). - -Definition transpose {T} {m n : nat} (mat:matrix T m n) : matrix T n m - := build_matrix (fun i j => mnth mat j i). - -Definition ConstMatrix {T} (n m : nat) (c:T) : matrix T n m := - ConstVector n (ConstVector m c). - -Definition matrixo_to_omatrix {T} {m n} (v:matrix (option T) m n) : option (matrix T m n) - := vectoro_to_ovector (vmap vectoro_to_ovector v). - -Definition mmap2 {A B C} {n m} (f:A->B->C) (v1 : matrix A n m) (v2 : matrix B n m) : matrix C n m := vmap2 (fun r1 r2 => vmap2 f r1 r2) v1 v2. - -Definition mforall {A} {m n:nat} (P: A -> Prop) (m:matrix A m n) : Prop - := vforall (fun x => vforall P x) m. - -End Matrix. - -Section Tensor. -Fixpoint tensor T (l:list nat) : Type - := match l with - | List.nil => T - | x::l' => vector (tensor T l') x - end. - -Lemma tensor0 T : tensor T List.nil = T. -Proof. - reflexivity. -Qed. - -Lemma tensor1 T n : tensor T (n::List.nil) = vector T n. -Proof. - reflexivity. -Qed. - -Lemma tensor_app T l1 l2 : tensor (tensor T l1) l2 = tensor T (l2++l1). -Proof. - revert l1. - induction l2; intros l1; simpl; trivial. - now rewrite IHl2. -Qed. - -Fixpoint ConstTensor {T} (l : list nat) (c:T) : (tensor T l) := - match l with - | List.nil => c - | x::l' => ConstVector x (ConstTensor l' c) - end. - -Fixpoint Tensor_map {A B} {dims:list nat} (f:A->B) : tensor A dims -> tensor B dims - := match dims with - | List.nil => fun x => f x - | x::l' => vmap (Tensor_map f) - end. - -Definition scalar {T} (c:T) : tensor T List.nil := c. - -End Tensor. - -Inductive NumericType - := FloatType - | IntType. - - -Definition ntype_interp (n:NumericType) : Type - := match n with - | FloatType => nat - | IntType => Z - end. - - Structure BigArray (ln : list nat) (T : NumericType) : Type := - Tensor { tdata :> list (ntype_interp T); _ : length tdata = List.fold_right Nat.mul 1%nat ln }. - - Structure Array1 (n : nat) (T : NumericType) : Type := - array1 { a1data :> list (ntype_interp T); _ : length a1data = n}. - - Structure Array2 (n m : nat) (T : NumericType) : Type := - array2 { a2data :> list (ntype_interp T); _ : length a2data = n * m}. - -Definition tensor_abs_type (T:NumericType) (dims:list nat) := tensor (ntype_interp T) dims. - -Class TensorDef := - { - tensor_t (T:NumericType) (dims:list nat) : Type - ; tensor_repr {T:NumericType} {dims:list nat} : tensor_t T dims -> tensor_abs_type T dims -> Prop - - ; tensor_const {T} (dims : list nat) (c:ntype_interp T) : tensor_t T dims - ; tensor_const_p {T} (dims : list nat) (c:ntype_interp T) : tensor_repr (tensor_const dims c) (ConstTensor dims c) - - ; tensor_map {A B} {dims : list nat} (f:ntype_interp A-> ntype_interp B) (t:tensor_t A dims) : tensor_t B dims - ; tensor_map_p {A B} {dims : list nat} (f:ntype_interp A-> ntype_interp B) (t:tensor_t A dims) : - forall r, tensor_repr t r -> - tensor_repr (tensor_map f t) (Tensor_map f r) - - (* ; tensor_nth {A} {dims : list nat} (indices:list nat) (indices_in_range:True) (t:tensor A dims) : A *) - (* ; tensor_nth_p {A:Type} {dims : list nat} (indices:list nat) (indices_in_range:True) (t:tensor_t A dims) : *) - (* forall r, tensor_repr t r -> *) - (* tensor_repr (tensor_nth indices indices_in_range t) (tensor_nth indices indices_in_range r) *) - }. - -(* -Class TensorDefExt {base:TensorDef} := - { - tensor_transpose; - }. -*) - (* ; tensor_nth {A} {dims : list nat} (indices:list nat) (indices_in_range:True) (t:tensor A dims) : A *) - (* ; tensor_nth_p {A:Type} {dims : list nat} (indices:list nat) (indices_in_range:True) (t:tensor_t A dims) : *) - (* forall r, tensor_repr t r -> *) - (* tensor_repr (tensor_nth indices indices_in_range t) (tensor_nth indices indices_in_range r) *) - -(* Instance trivial_TensorDef : TensorDef := *) -(* { *) -(* tensor_t := tensor; *) -(* tensor_repr _ _ a b := a = b *) -(* }. *) -(* -Fixpoint flat_list_represent_tensor {T} {dims} (l:list A) (t:tensor T dims) : Prop - := - -Instance BigArray_TensorDef : TensorDef - := { - tensor_t A dims := list A; - tensor_repr T dims (l:list A) (tensor T dims) - := fix - }. -*) - -Require Import Floatish. -Section float_ops. - Context {floatish_impl:floatish}. - Local Open Scope float. - - Definition vsum {m:nat} (v:vector float m) : float - := List.fold_right Fplus 0 (to_list v). - - Definition msum {m n:nat} (v:matrix float m n) : float := - vsum (vmap vsum v). - - Definition vdot {m:nat} (v1 v2 : vector float m) : float := - List.fold_right Fplus 0 - (List.map (fun '(a,b) => a * b) - (combine (to_list v1) (to_list v2))). - - Definition vadd {m:nat} (v1 v2 : vector float m) := - vmap (fun '(a,b) => a+b) (vcombine v1 v2). - - Definition madd {m n:nat} (mat1 mat2 : matrix float m n) := - mmap (fun '(a,b) => a+b) (mcombine mat1 mat2). - - Definition matrix_vector_mult {n m} (l : matrix float n m)(r : vector float m) : vector float n := - vmap (fun l1 => vdot l1 r) l. - - Definition matrix_vector_add {n m} (l : matrix float n m) (r : vector float n) : matrix float n m := - build_matrix (fun i j => (vnth (vnth l i) j) + (vnth r i)). - -(* - transpose (vmap (fun l1 => vadd l1 r) (transpose l)). - *) - - Definition matrix_mult {n m p} (l : matrix float n m)(r : matrix float m p) : matrix float n p := - build_matrix (fun i k => vsum (build_vector - (fun j => (vnth (vnth l i) j) * - (vnth (vnth r j) k)))). - -(* - transpose (vmap (fun r1 => matrix_vector_mult l r1) (transpose r)). -*) - -End float_ops. - - - - - - diff --git a/ocaml/Makefile b/ocaml/Makefile deleted file mode 100644 index 18179501..00000000 --- a/ocaml/Makefile +++ /dev/null @@ -1,78 +0,0 @@ -# -# Copyright 2015-2016 IBM Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# User-level configuration -# include ../Makefile.config -# Contains the list of all the Coq modules -include ../Makefile.coq_modules - -## Configuraton -NNOPT_HOME=$(CURDIR)/.. - -############# Shouldn't have to be changed after this -OCAMLBUILD= ocamlbuild \ - -no-links -classic-display \ - -tags annot -use-ocamlfind -package unix -package base64 -package csv - -MENHIRFLAG=-use-menhir -#MENHIRFLAG= - -## Mains -MAIN=nnopt - -TARGET=native - -## Toplevel -all: ../bin/$(MAIN) - -native: ../bin/$(MAIN) - -## Extraction -VO_FILES = $(MODULES:%=../coq/%.vo) - -extracted: extracted/StaticConfig.ml extracted/NnoptExtracted.ml extracted/NnoptExtracted.mli - -extracted/StaticConfig.ml extracted/NnoptExtracted.ml extracted/NnoptExtracted.mli : $(VO_FILES) NnoptExtraction.v - rm -rf extracted - mkdir -p extracted - echo "(* This file is generated *)" > extracted/StaticConfig.ml - echo "let nnopt_home = \"$(NNOPT_HOME)\"" >> extracted/StaticConfig.ml - (coqc -R ../coq FormalML NnoptExtraction.v) - -## Native -../bin/$(MAIN): extracted $(MAIN).$(TARGET) ../bin - cp _build/$(MAIN).$(TARGET) ../bin/$(MAIN) - -../bin: - @mkdir -p ../bin - -$(MAIN).$(TARGET): extracted - $(OCAMLBUILD) $(MENHIRFLAG) -Is extracted -Is src $(MAIN).$(TARGET) - -## Clean - -clean: - ocamlbuild -clean -no-log - rm -rf _build - rm -f ../bin/$(MAIN) - -cleanall: clean - rm -f NnoptExtraction.glob NnoptExtraction.vo .NnoptExtraction.aux - rm -rf extracted - rm -rf *~ - -.NOTPARALLEL: - diff --git a/ocaml/NnoptExtraction.v b/ocaml/NnoptExtraction.v deleted file mode 100644 index 08a3073b..00000000 --- a/ocaml/NnoptExtraction.v +++ /dev/null @@ -1,24 +0,0 @@ -(* Configuration of the extraction *) - -Require Extraction. -Extraction Language OCaml. -Require Import ExtrOcamlBasic ExtrOcamlString ExtrOcamlZInt ExtrOcamlNatInt. -Extraction Blacklist String List. - -Require Import FloatishIEEE. -Require Import ExtrFloatishIEEE. - -(* Require Import ExtrR. *) -(* Our stuff modules *) - -Require API. - -(* Workaround for https://github.com/coq/coq/issues/13288 , suggested by a comment on the issue. - Coq extraction currently creates a clash between the extracted Decimal.int and the - ocaml int type. -*) -Extract Inductive Decimal.int => unit [ "(fun _ -> ())" "(fun _ -> ())" ] "(fun _ _ _ -> assert false)". - -Cd "./extracted". - -Recursive Extraction Library API. diff --git a/ocaml/nnopt.ml b/ocaml/nnopt.ml deleted file mode 100644 index e2f1cbd6..00000000 --- a/ocaml/nnopt.ml +++ /dev/null @@ -1,74 +0,0 @@ -open Format - -open API -open Util -open Pretty - - -let () = Format.printf "Result of running opt: %a\n" (pretty_visible_option pretty_df_env) API.opt ;; -let () = Format.printf "Result of running opt2: %a\n" (pretty_visible_option pretty_df_env) API.opt2 ;; -let () = Format.printf "The testopt environment: %a\n" pretty_df_env API.testopt ;; -let () = Format.printf "The testreeopt environment: %a\n" pretty_df_env API.testreeopt ;; - -let () = Format.printf "The gradenv environment: %a\n" pretty_df_env API.gradenv ;; -let () = Format.printf "The gradenv_tree environment: %a\n" pretty_df_env API.gradenv_tree ;; - -let () = Format.printf "The test_update environment: %a\n" pretty_df_env API.test_update ;; - -let () = Format.printf "The test environment: %a\n" pretty_df_env API.test_env ;; - -let data = read_int_matrix_from_csv "breast-cancer-wisconsin.data" ;; -let actual_data = API.discard_first data ;; - -let () = Format.printf "first part of data without the first column: %d\n" (List.hd (List.hd actual_data)) -let normalized_data = API.normalizeIntData actual_data ;; -let () = Format.printf "first 10 rows of normalized data without the first column: \n%a\n" ( pretty_matrix 10 10) normalized_data ;; - -let () = Random.self_init() - -let randomStream = mkIndexedStream 0 (Obj.magic (API.random_float_vector ())) ;; -let fvals = fst(streamtake 5 randomStream) ;; -let () = Format.printf "random list : %a\n" (pretty_blist pp_print_float) fvals ;; - -let init_env = init_env2 9 15 1 (char_list_of_string "w") (char_list_of_string "b") - (Obj.magic (random_float_matrix ())) (Obj.magic (random_float_matrix ())) ;; -let () = Format.printf "Init environment: %a\n" pretty_df_env init_env ;; - -let wval = eval_wisconsin_batch 10 (Obj.magic init_env) (Obj.magic normalized_data) ;; -let () = Format.printf "wisconsin init loss value : %a\n" (pretty_blist pp_print_float) (Obj.magic wval) ;; - -let wval2 = wisconsin_test 10 100 (Obj.magic init_env) (Obj.magic normalized_data) ;; -let () = Format.printf "wisconsin loss value : %a\n" (pretty_blist pp_print_float) (Obj.magic wval2) ;; -(* -let wenv = wisconsin_test_env 6 10 (Obj.magic init_env) (Obj.magic normalized_data) ;; -let () = Format.printf "wisconsin test env: %a\n" pretty_df_env wenv ;; - -let nnval = nn_test_val ;; -let () = Format.printf "nn test init loss value : %a\n" (pretty_blist pp_print_float) (Obj.magic nnval) ;; - -let wval3 = nn_test 1 ;; -let () = Format.printf "NN test loss value : %a\n" (pretty_blist pp_print_float) (Obj.magic wval3) ;; - -let wenv3 = nn_test_env 1 ;; -let () = Format.printf "NN test env: %a\n" pretty_df_env wenv3 ;; - -let wenv4 = nn_test_gradenv ;; -let () = Format.printf "NN gradenv env: %a\n" pretty_df_env wenv4 ;; - -let wenv5 = nn_test_gradenv_tree ;; -let () = Format.printf "NN gradenv env tree: %a\n" pretty_df_env wenv5 ;; -*) -(* -let gradenvtree = wisconsin_gradenv_tree 1 (Obj.magic init_env) (Obj.magic normalized_data) ;; -let () = Format.printf "wisconsin gradenv_tree : %a\n" pretty_df_env gradenvtree ;; - -let gradenv = wisconsin_gradenv 1 (Obj.magic init_env) (Obj.magic normalized_data) ;; -let () = Format.printf "wisconsin gradenv : %a\n" pretty_df_env gradenv ;; -*) - - - - - - - diff --git a/ocaml/src/Pretty.ml b/ocaml/src/Pretty.ml deleted file mode 100644 index 8259be20..00000000 --- a/ocaml/src/Pretty.ml +++ /dev/null @@ -1,70 +0,0 @@ -open Format - -open API - -let rec subVar_to_list sv = - begin match sv with - | Name s -> (Util.string_of_char_list s, []) - | Sub (v,i) -> let (s,r) = subVar_to_list v in - (s, i::r) - end - -let pretty_const_string s ff _ = pp_print_string ff s - -let pretty_blist ?(bstart="[") ?(bend="]") ?(bsep=",") pp ff l = - pp_print_string ff bstart ; - (pp_print_list ~pp_sep:(pretty_const_string bsep)) pp ff l ; - pp_print_string ff bend - -let pretty_subVar ff sv = - let (s,l) = subVar_to_list sv in - pp_print_string ff s ; - if l <> [] - then pretty_blist pp_print_int ff l - -let pretty_definition_function_types ff dft = - begin match dft with - | DTfloat -> fprintf ff "%s" "float" - | DTVector m -> fprintf ff "%s[%d]" "float" m - | DTMatrix (m,n) -> fprintf ff "%s[%d,%d]" "float" m n - end - -let pretty_var_type ff (sv, dft) = - fprintf ff "%a{%a}" pretty_subVar sv pretty_definition_function_types dft - -let pretty_vector n ff v = - let fs = List.init n (fun i -> Obj.magic (v i)) in - pretty_blist pp_print_float ff fs - -let pretty_matrix m n ff v = - let fs = List.init m (fun i -> List.init n (fun j -> Obj.magic (v i j))) in - pretty_blist (pretty_blist pp_print_float) ff fs - -let pretty_definition_function_types_interp ff dft value = - begin match dft with - | DTfloat -> pp_print_float ff (Obj.magic value) - | DTVector m -> pretty_vector m ff (Obj.magic value) - | DTMatrix (m,n) -> pretty_matrix m n ff (Obj.magic value) - end - -let pretty_env_entry_type ff (ExistT ((sv, dft), value)) = - pretty_subVar ff sv ; - pp_print_string ff "->" ; - pretty_definition_function_types_interp ff dft value - -let pretty_df_env ff env = - pretty_blist ~bstart:"{" ~bend:"}" pretty_env_entry_type ff env - -(* This should be replaced with Format.pp_print_option from ocaml >=4.08 *) -let pretty_option ?none some formatter value = - begin match value with - | None -> - begin - match none with - | None -> () - | Some none -> none formatter () - end - | Some value -> some formatter value - end - -let pretty_visible_option some formatter value = pretty_option ~none:(pretty_const_string "None") some formatter value diff --git a/ocaml/src/Pretty.mli b/ocaml/src/Pretty.mli deleted file mode 100644 index 2dc36726..00000000 --- a/ocaml/src/Pretty.mli +++ /dev/null @@ -1,19 +0,0 @@ -open API -open DefinedFunctions -open Vector - -val pretty_subVar : Format.formatter -> coq_SubVar -> unit - -val pretty_definition_function_types : Format.formatter -> definition_function_types -> unit - -val pretty_vector : int -> Format.formatter -> float coq_Vector -> unit -val pretty_matrix : int -> int -> Format.formatter -> float coq_Matrix -> unit - -val pretty_var_type : Format.formatter -> var_type -> unit - -val pretty_env_entry_type : Format.formatter -> env_entry_type -> unit -val pretty_df_env : Format.formatter -> df_env -> unit - -val pretty_visible_option : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a option -> unit - -val pretty_blist : ?bstart:string -> ?bend:string -> ?bsep:string -> (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a list -> unit diff --git a/ocaml/src/Util.ml b/ocaml/src/Util.ml deleted file mode 100644 index 06252aa5..00000000 --- a/ocaml/src/Util.ml +++ /dev/null @@ -1,28 +0,0 @@ -let string_of_char_list l = - let b = Bytes.create (List.length l) in - let i = ref 0 in - List.iter (fun c -> Bytes.set b !i c; incr i) l; - Bytes.to_string b - -let char_list_of_string s = - let l = ref [] in - String.iter (fun c -> l := c :: !l) s; - List.rev !l - -let read_int_matrix_from_csv name = - let sdata = Csv.load name in - List.map (List.map int_of_string) sdata - -let rec memoized_vector f = - let cache = Hashtbl.create 10 in - begin fun n -> - try Hashtbl.find cache n - with Not_found -> begin - let x = f n in - Hashtbl.add cache n x; x - end - end - -let random_float_vector () = memoized_vector (fun _ -> Random.float 1.0) -let random_float_matrix () = memoized_vector (fun _ -> random_float_vector ()) - diff --git a/ocaml/src/Util.mli b/ocaml/src/Util.mli deleted file mode 100644 index 6d31874f..00000000 --- a/ocaml/src/Util.mli +++ /dev/null @@ -1,9 +0,0 @@ -val string_of_char_list : char list -> string -val char_list_of_string : string -> char list - -val read_int_matrix_from_csv : string -> int list list - -val memoized_vector : (int -> 'a) -> int -> 'a - -val random_float_vector : unit -> int -> float -val random_float_matrix : unit -> int -> int -> float diff --git a/coq/CertRL/LM/README.md b/rocq/CertRL/LM/README.md similarity index 100% rename from coq/CertRL/LM/README.md rename to rocq/CertRL/LM/README.md diff --git a/coq/CertRL/LM/R_compl.v b/rocq/CertRL/LM/R_compl.v similarity index 98% rename from coq/CertRL/LM/R_compl.v rename to rocq/CertRL/LM/R_compl.v index 04d53c24..a60be328 100644 --- a/coq/CertRL/LM/R_compl.v +++ b/rocq/CertRL/LM/R_compl.v @@ -116,9 +116,9 @@ generalize (archimed (ln d / ln k)); intros (Y1,_). rewrite Rmult_comm. apply Rlt_le_trans with (1:=Y1). generalize (up (ln d / ln k)); clear; intros x. -rewrite INR_IZR_INZ, Zabs2Nat.id_abs. +rewrite INR_IZR_INZ, Zabs.inj_Zabs_nat. apply IZR_le. -case (Zabs_spec x); intros (T1,T2); rewrite T2; auto with zarith. +case (Zabs.Zabs_spec x); intros (T1,T2); rewrite T2; auto with zarith. (* k = 0 *) exists 1%nat. rewrite <- Hk';simpl. diff --git a/coq/CertRL/LM/check_sub_structure.v b/rocq/CertRL/LM/check_sub_structure.v similarity index 87% rename from coq/CertRL/LM/check_sub_structure.v rename to rocq/CertRL/LM/check_sub_structure.v index d3c2480b..b4a3ea6e 100644 --- a/coq/CertRL/LM/check_sub_structure.v +++ b/rocq/CertRL/LM/check_sub_structure.v @@ -29,7 +29,7 @@ Context {CCP: compatible_g P}. Record Sg:= mk_Sg { val :> G ; H: P val -}. + }. Lemma Sg_eq: forall (x y:Sg), (val x = val y) -> x = y. Proof. @@ -45,11 +45,11 @@ Qed. Definition Sg_zero : Sg := mk_Sg zero (compatible_g_zero P CCP). Definition Sg_plus (x y : Sg) : Sg := - mk_Sg (plus x y) + mk_Sg (plus (val x) (val y)) (compatible_g_plus P (val x) (val y) CCP (H x) (H y)). Definition Sg_opp (x : Sg) : Sg := - mk_Sg (opp x) + mk_Sg (opp (val x)) (compatible_g_opp P (val x) CCP (H x)). Lemma Sg_plus_comm: forall (x y:Sg), Sg_plus x y = Sg_plus y x. @@ -81,12 +81,18 @@ unfold Sg_plus; simpl. apply plus_opp_r. Qed. +Definition Sg_AbelianMonoid_mixin := + AbelianMonoid.Mixin Sg Sg_plus Sg_zero Sg_plus_comm + Sg_plus_assoc Sg_plus_zero_r. + +Canonical Sg_AbelianMonoid := + AbelianMonoid.Pack Sg (Sg_AbelianMonoid_mixin) Sg. + Definition Sg_AbelianGroup_mixin := - AbelianGroup.Mixin Sg Sg_plus Sg_opp Sg_zero Sg_plus_comm - Sg_plus_assoc Sg_plus_zero_r Sg_plus_opp_r. + AbelianGroup.Mixin Sg_AbelianMonoid Sg_opp Sg_plus_opp_r. Canonical Sg_AbelianGroup := - AbelianGroup.Pack Sg (Sg_AbelianGroup_mixin) Sg. + AbelianGroup.Pack Sg (AbelianGroup.Class _ (Sg_AbelianMonoid_mixin) Sg_AbelianGroup_mixin) Sg. End Subgroups. @@ -135,12 +141,18 @@ unfold Sg_plus; unfold Sg_scal; simpl. apply scal_distr_r. Qed. +Definition Sg_MAbelianMonoid_mixin := + AbelianMonoid.Mixin (@Sg _ P) Sg_Mplus Sg_zero Sg_plus_comm + Sg_plus_assoc Sg_plus_zero_r. + +Canonical Sg_MAbelianMonoid := + AbelianMonoid.Pack(@Sg _ P) (Sg_MAbelianMonoid_mixin) (@Sg _ P). + Definition Sg_MAbelianGroup_mixin := - AbelianGroup.Mixin Sg Sg_Mplus Sg_opp Sg_zero Sg_plus_comm - Sg_plus_assoc Sg_plus_zero_r Sg_plus_opp_r. + AbelianGroup.Mixin Sg_MAbelianMonoid Sg_opp Sg_plus_opp_r. Canonical Sg_MAbelianGroup := - AbelianGroup.Pack Sg (Sg_MAbelianGroup_mixin) (@Sg _ P). + AbelianGroup.Pack (@Sg _ P) (AbelianGroup.Class _ (Sg_MAbelianMonoid_mixin) Sg_MAbelianGroup_mixin) (@Sg _ P). Definition Sg_ModuleSpace_mixin := ModuleSpace.Mixin R_Ring (Sg_MAbelianGroup) diff --git a/coq/CertRL/LM/compatible.v b/rocq/CertRL/LM/compatible.v similarity index 82% rename from coq/CertRL/LM/compatible.v rename to rocq/CertRL/LM/compatible.v index 40361961..573926d6 100644 --- a/coq/CertRL/LM/compatible.v +++ b/rocq/CertRL/LM/compatible.v @@ -133,10 +133,11 @@ intros. split. unfold compatible_m in *. unfold compatible_g in *. assert (u = opp u'). +replace (@zero (ModuleSpace.AbelianMonoid R_Ring E)) with (@zero (AbelianGroup.AbelianMonoid (ModuleSpace.AbelianGroup R_Ring E))) in H2 by reflexivity. rewrite <- plus_opp_r with u' in H2. rewrite plus_comm in H2. -apply plus_reg_l in H2. -trivial. +eapply plus_reg_l. +eapply H2. assert (phi' u). rewrite H3 in H2. rewrite H3. @@ -144,8 +145,9 @@ rewrite <- scal_opp_one. apply (proj2 Cphi'). trivial. apply H; trivial. assert (u' = opp u). +replace (@zero (ModuleSpace.AbelianMonoid R_Ring E)) with (@zero (AbelianGroup.AbelianMonoid (ModuleSpace.AbelianGroup R_Ring E))) in H2 by reflexivity. rewrite <- plus_opp_r with u in H2. -apply plus_reg_l in H2. trivial. +eapply plus_reg_l; eauto. assert (phi u'). rewrite H3 in H2. rewrite H3. @@ -156,7 +158,7 @@ Qed. Lemma plus_u_opp_v : forall (u v : E), u = v <-> (plus u (opp v) = zero). intros; split. -+ intros. rewrite H. rewrite plus_opp_r. reflexivity. ++ intros. rewrite H. now generalize (plus_opp_r v). + intros. apply plus_reg_r with (opp v). rewrite plus_opp_r; trivial. Qed. @@ -179,7 +181,9 @@ destruct Cphi as ((Cphi1,(x,Cphi2)),Cphi3). destruct Cphi' as ((Cphi'1,(x',Cphi'2)),Cphi'3). assert (plus (plus u (opp v)) (plus u' (opp v')) = zero). rewrite plus_assoc_gen. rewrite H4. -rewrite plus_assoc_gen. rewrite plus_opp_r. rewrite plus_opp_r. +rewrite plus_assoc_gen. +replace (plus v' (opp v')) with (@zero (ModuleSpace.AbelianMonoid R_Ring E)) by (symmetry; generalize (plus_opp_r v'); intros HH; apply HH). +replace (plus v (opp v)) with (@zero (ModuleSpace.AbelianMonoid R_Ring E)) by (symmetry; generalize (plus_opp_r v); intros HH; apply HH). rewrite plus_zero_r. reflexivity. rewrite plus_u_opp_v. rewrite (plus_u_opp_v u' v'). @@ -203,9 +207,18 @@ apply (Cphi'3 x (opp one)) in H1. assert ((x = zero) /\ (opp x = zero)). apply H. trivial. rewrite <- (scal_zero_l y). trivial. rewrite <- scal_opp_one. trivial. +simpl. +replace (@zero + (AbelianMonoid.Pack (ModuleSpace.sort R_Ring E) + (AbelianGroup.base (ModuleSpace.sort R_Ring E) + (ModuleSpace.base R_Ring + match E return Type with + | @ModuleSpace.Pack _ T _ _ => T + end (ModuleSpace.class R_Ring E))) + (ModuleSpace.sort R_Ring E))) with (@zero (ModuleSpace.AbelianMonoid R_Ring E)) by reflexivity. rewrite <- (scal_zero_l y'). trivial. -rewrite plus_opp_r. -rewrite plus_zero_l. reflexivity. +rewrite plus_zero_l. +apply (plus_opp_r x). intuition. Qed. diff --git a/coq/CertRL/LM/continuous_linear_map.v b/rocq/CertRL/LM/continuous_linear_map.v similarity index 99% rename from coq/CertRL/LM/continuous_linear_map.v rename to rocq/CertRL/LM/continuous_linear_map.v index cd56c54b..a25ec153 100644 --- a/coq/CertRL/LM/continuous_linear_map.v +++ b/rocq/CertRL/LM/continuous_linear_map.v @@ -56,7 +56,7 @@ specialize (H2 (norm x)). assert (norm x = 0). case (Req_dec (norm x) 0); trivial. intros H3. -elimtype False. +exfalso. apply H2. exists x. split; trivial. @@ -93,7 +93,7 @@ Proof. intros A H1 H2 H3. case H2; trivial. intros H4. -elimtype False. +exfalso. specialize (Is_only_zero_set_correct3 H1). intros H5. assert (~(exists x : E, x <> zero)). @@ -104,7 +104,7 @@ intros x. case (Req_dec (norm x) 0). apply norm_eq_zero. intros H6. -elimtype False. +exfalso. apply H. exists x. intros H7; apply H6. @@ -236,7 +236,6 @@ Proof. intros f. generalize (operator_norm_ge_0 f). destruct (operator_norm f); try easy. -intros _; apply Rle_refl; simpl. Qed. Lemma operator_norm_helper: @@ -889,12 +888,19 @@ unfold clm_plus, clm_opp; simpl. apply: plus_opp_r. Qed. + +Definition clm_AbelianMonoid_mixin := + AbelianMonoid.Mixin clm clm_plus clm_zero clm_plus_comm + clm_plus_assoc clm_plus_zero_r. + +Canonical clm_AbelianMonoid := + AbelianMonoid.Pack clm (clm_AbelianMonoid_mixin) clm. + Definition clm_AbelianGroup_mixin := - AbelianGroup.Mixin clm clm_plus clm_opp clm_zero clm_plus_comm - clm_plus_assoc clm_plus_zero_r clm_plus_opp_r. + AbelianGroup.Mixin clm_AbelianMonoid clm_opp clm_plus_opp_r. Canonical clm_AbelianGroup := - AbelianGroup.Pack clm (clm_AbelianGroup_mixin) clm. + AbelianGroup.Pack clm (AbelianGroup.Class _ (clm_AbelianMonoid_mixin) clm_AbelianGroup_mixin) clm. (** Clm is ModuleSpace *) @@ -1308,7 +1314,7 @@ specialize (H2 (norm x)). assert (norm x = 0). case (Req_dec (norm x) 0); trivial. intros H3. -elimtype False. +exfalso. apply H2. exists x. split; trivial. @@ -1391,7 +1397,7 @@ intros A H1 H2 H3. case H2; trivial. intros H4. apply H3. -elimtype False. +exfalso. apply H1. apply Is_only_zero_set_correct2''_phi. intros x. @@ -1460,7 +1466,6 @@ Proof. intros f. generalize (operator_norm_ge_0_phi f). destruct (operator_norm_phi f); try easy. -intros _; apply Rle_refl; simpl. Qed. Lemma operator_norm_helper'_phi: diff --git a/coq/CertRL/LM/fixed_point.v b/rocq/CertRL/LM/fixed_point.v similarity index 98% rename from coq/CertRL/LM/fixed_point.v rename to rocq/CertRL/LM/fixed_point.v index 7f4bd668..228c7d51 100644 --- a/coq/CertRL/LM/fixed_point.v +++ b/rocq/CertRL/LM/fixed_point.v @@ -105,7 +105,7 @@ Proof. intros f a k p m D H1 H2 H3 H4. case_eq m. (* *) -intros _; rewrite plus_0_r. +intros _; rewrite Nat.add_0_r. assert (L:(0 < k ^ p / (1 - k) * D)). apply Rmult_lt_0_compat; trivial. apply Rdiv_lt_0_compat. @@ -184,7 +184,8 @@ Proof. intros a k p m n D (H1,H1') H2 Phi_a H3 H4 Hp Hm Hn. case H1; intros P. (* *) -case (le_or_lt p m); intros H5. + +case (Nat.le_gt_cases p m); intros H5. (* . *) replace m with (p+(m-p))%nat. apply ball_le with (k^p/(1-k) *D). @@ -197,7 +198,7 @@ apply Rle_pow_le; try assumption. split; try left; try assumption. apply dist_iter; try assumption. split; assumption. -now apply le_plus_minus_r. +auto with arith. (* . *) apply ball_sym. replace p with (m+(p-m))%nat. @@ -211,7 +212,7 @@ apply Rle_pow_le; try assumption. split; try left; assumption. apply dist_iter; try assumption. split; assumption. -now apply le_plus_minus_r, lt_le_weak. +auto with arith. (* *) apply ball_le with 0. rewrite <- P. @@ -219,8 +220,8 @@ rewrite pow_i; try assumption. right; unfold Rdiv; ring. apply dist_iter_aux_zero; try assumption. now rewrite P. -now apply lt_le_trans with n. -now apply lt_le_trans with n. +now apply Nat.lt_le_trans with n. +now apply Nat.lt_le_trans with n. Qed. End iter_dist. @@ -403,11 +404,11 @@ intros P Q (N1,H1) (N2,H2). exists (max N1 N2). intros n Hn; split. apply H1. -apply le_trans with (2:=Hn). -apply Max.le_max_l. +apply Nat.le_trans with (2:=Hn). +apply Nat.le_max_l. apply H2. -apply le_trans with (2:=Hn). -apply Max.le_max_r. +apply Nat.le_trans with (2:=Hn). +apply Nat.le_max_r. intros P Q H (N,HN). exists N. intros n Hn. @@ -746,7 +747,7 @@ Proof. intros f P a k p m D H1 H2 HH H3 H4. case_eq m. (* *) -intros _; rewrite plus_0_r. +intros _; rewrite Nat.add_0_r. assert (L:(0 < k ^ p / (1 - k) * D)). apply Rmult_lt_0_compat; trivial. apply Rdiv_lt_0_compat. @@ -825,7 +826,7 @@ Proof. intros a k p m n D (H1,H1') H2 Phi_a H3 H4 Hp Hm Hn. case H1; intros P0. (* *) -case (le_or_lt p m); intros H5. +case (Nat.le_gt_cases p m); intros H5. (* . *) replace m with (p+(m-p))%nat. apply ball_le with (k^p/(1-k) *D). @@ -842,7 +843,7 @@ intros p0. apply phi_iter_f. trivial. trivial. -now apply le_plus_minus_r. +auto with arith. (* . *) apply ball_sym. replace p with (m+(p-m))%nat. @@ -858,7 +859,7 @@ apply dist_iter_phi with phi; try assumption. split; assumption. intros p0. apply phi_iter_f; trivial. -now apply le_plus_minus_r, lt_le_weak. +auto with arith. (* *) apply ball_le with 0. rewrite <- P0. @@ -866,8 +867,8 @@ rewrite pow_i; try assumption. right; unfold Rdiv; ring. apply dist_iter_aux_zero_phi; try assumption. now rewrite P0. -now apply lt_le_trans with n. -now apply lt_le_trans with n. +now apply Nat.lt_le_trans with n. +now apply Nat.lt_le_trans with n. Qed. End iter_dist_sub. @@ -979,11 +980,11 @@ intros P Q (N1,H1) (N2,H2). exists (max N1 N2). intros n Hn; split. apply H1. -apply le_trans with (2:=Hn). -apply Max.le_max_l. +apply Nat.le_trans with (2:=Hn). +apply Nat.le_max_l. apply H2. -apply le_trans with (2:=Hn). -apply Max.le_max_r. +apply Nat.le_trans with (2:=Hn). +apply Nat.le_max_r. intros P Q H (N,HN). exists N. intros n Hn. diff --git a/coq/CertRL/LM/hierarchyD.v b/rocq/CertRL/LM/hierarchyD.v similarity index 99% rename from coq/CertRL/LM/hierarchyD.v rename to rocq/CertRL/LM/hierarchyD.v index 371ebff9..0575cb48 100644 --- a/coq/CertRL/LM/hierarchyD.v +++ b/rocq/CertRL/LM/hierarchyD.v @@ -416,7 +416,7 @@ Lemma clm_lim_aux2: forall (Fi : ((clm E F) -> Prop) -> Prop), Proof. intros Fi FFi HFc t eps. case (is_zero_dec t); intros Ht. -exists zero. +exists (@zero F). unfold Fri, Fr, cleanFilter. exists (fun _ => True). split. @@ -461,7 +461,7 @@ Definition clm_lim (Fi : ((clm E F) -> Prop) -> Prop) lim (Fri Fi t) (clm_lim_aux1 Fi H1 t) (clm_lim_aux2 Fi H1 H2 t). Lemma ball_plus_plus: forall (x a y b:F) (eps:posreal), ball x eps a -> ball y eps b -> - ball (plus x y) (2*(@norm_factor R_AbsRing F)*eps) (plus a b). + ball (@plus F x y) (2*(@norm_factor R_AbsRing F)*eps) (@plus F a b). Proof. intros x a y b eps Hx Hy. pose (v:=@norm_factor R_AbsRing F); fold v. @@ -642,7 +642,7 @@ apply Rle_trans with (minus (g x) (f x)))). right; f_equal. unfold minus; rewrite <- plus_assoc. -apply trans_eq with (plus (clm_lim Fi H1 H2 x) (opp (f x))). +apply trans_eq with (@plus F (clm_lim Fi H1 H2 x) (opp (f x))). reflexivity. f_equal. rewrite plus_assoc. diff --git a/coq/CertRL/LM/hilbert.v b/rocq/CertRL/LM/hilbert.v similarity index 95% rename from coq/CertRL/LM/hilbert.v rename to rocq/CertRL/LM/hilbert.v index 42cba33f..55b0a345 100644 --- a/coq/CertRL/LM/hilbert.v +++ b/rocq/CertRL/LM/hilbert.v @@ -79,7 +79,7 @@ Module Exports. Coercion base : class_of >-> ModuleSpace.class_of. Coercion mixin : class_of >-> mixin_of. Coercion sort : type >-> Sortclass. -Coercion AbelianGroup : type >-> AbelianGroup.type. +Global Coercion AbelianGroup : type >-> AbelianGroup.type. Canonical AbelianGroup. Coercion ModuleSpace : type >-> ModuleSpace.type. Canonical ModuleSpace. @@ -133,7 +133,7 @@ Proof. apply PreHilbert.ax2. Qed. -Lemma inner_eq_0 : forall x, inner x x = 0 -> x = zero. +Lemma inner_eq_0 : forall x, inner x x = 0 -> x = (@zero E). Proof. apply PreHilbert.ax3. Qed. @@ -143,7 +143,7 @@ Proof. intros x; apply sqrt_pos. Qed. -Lemma norm_eq_0: forall x, Hnorm x = 0 -> x = zero. +Lemma norm_eq_0: forall x, Hnorm x = 0 -> x = (@zero E). Proof. intros x; unfold norm; intros H. assert (inner x x = 0). @@ -163,7 +163,7 @@ intros x y l. now rewrite inner_sym inner_scal_l inner_sym. Qed. -Lemma inner_zero_l: forall x, inner zero x = 0. +Lemma inner_zero_l: forall x, inner (@zero E) x = 0. Proof. intros x. apply trans_eq with (inner (scal 0 zero) x). @@ -172,18 +172,18 @@ rewrite inner_scal_l. apply Rmult_0_l. Qed. -Lemma inner_zero_r: forall x, inner x zero = 0. +Lemma inner_zero_r: forall x, inner x (@zero E) = 0. Proof. intros x. rewrite inner_sym; apply inner_zero_l. Qed. -Lemma inner_plus_l : forall (x y z : E), inner (plus x y) z = inner x z + inner y z. +Lemma inner_plus_l : forall (x y z : E), inner (@plus E x y) z = inner x z + inner y z. Proof. apply PreHilbert.ax5. Qed. -Lemma inner_plus_r : forall (x y z : E), inner x (plus y z) = inner x y + inner x z. +Lemma inner_plus_r : forall (x y z : E), inner x (@plus E y z) = inner x y + inner x z. Proof. intros x y z. now rewrite inner_sym inner_plus_l 2!(inner_sym x). @@ -202,7 +202,7 @@ intros x y. now rewrite inner_sym inner_opp_r inner_sym. Qed. -Lemma norm_zero: Hnorm zero = 0. +Lemma norm_zero: Hnorm (@zero E) = 0. Proof. unfold Hnorm; now rewrite inner_zero_l sqrt_0. Qed. @@ -236,7 +236,7 @@ apply inner_ge_0. Qed. Lemma square_expansion_plus: forall x y, - inner (plus x y) (plus x y) = inner x x + 2 * inner x y + inner y y. + inner (@plus E x y) (plus x y) = inner x x + 2 * inner x y + inner y y. Proof. intros x y. rewrite inner_plus_l 2!inner_plus_r. @@ -260,7 +260,7 @@ Qed. (** Equalities and inequalities *) Lemma parallelogram_id: forall x y, - (inner (plus x y) (plus x y)) + (inner (minus x y) (minus x y)) + (inner (@plus E x y) (plus x y)) + (inner (minus x y) (minus x y)) = 2*((inner x x) + (inner y y)). Proof. intros x y. @@ -299,7 +299,7 @@ now rewrite 2!squared_norm. apply Rmult_le_pos; apply norm_ge_0. Qed. -Lemma norm_triangle: forall x y, Hnorm (plus x y) <= Hnorm x + Hnorm y. +Lemma norm_triangle: forall x y, Hnorm (@plus E x y) <= Hnorm x + Hnorm y. Proof. intros x y. apply Rsqr_incr_0_var; unfold Rsqr. @@ -361,7 +361,7 @@ apply plus_zero_r. Qed. Definition PreHilbert_UniformSpace_mixin := - UniformSpace.Mixin E zero ball ball_center ball_sym ball_triangle. + UniformSpace.Mixin E (@zero E) ball ball_center ball_sym ball_triangle. Canonical PreHilbert_UniformSpace := UniformSpace.Pack E (PreHilbert_UniformSpace_mixin) E. @@ -672,13 +672,13 @@ intros g Hg1 Hg2. unfold Hierarchy.ball; simpl; unfold ball; simpl. simpl in Hf2, Hg2. (* . *) -assert (M:4*Rsqr (norm (minus u (scal (/2) (plus f g)))) + Rsqr (norm (minus f g)) = - 2*Rsqr (norm (minus u f)) + 2*Rsqr (norm (minus u g))). +assert (M: 4*Rsqr (norm (@minus E u (scal (/2) (@plus E f g)))) + Rsqr (norm (@minus E f g)) = + 2*Rsqr (norm (@minus E u f)) + 2*Rsqr (norm (@minus E u g))). unfold Rsqr at 3 4; rewrite 2!squared_norm. rewrite <- Rmult_plus_distr_l. rewrite <- parallelogram_id. f_equal. -apply trans_eq with (Rsqr (2*norm (minus u (scal (/ 2) (plus f g))))). +apply trans_eq with (Rsqr (2*norm (minus u (scal (/ 2) (@plus E f g))))). unfold Rsqr; ring. replace 2 with (abs 2) at 1. 2: apply Rabs_right; left; apply Rlt_0_2. @@ -694,10 +694,15 @@ replace 2 with (plus 1 1). 2: unfold plus; simpl; ring. rewrite scal_distr_r scal_one. unfold minus; rewrite <- 2!plus_assoc; f_equal. -rewrite opp_plus 2!plus_assoc; f_equal. -apply plus_comm. +rewrite <- Rsqr_def. +do 3 f_equal. +rewrite (plus_comm (opp f)). +rewrite <- plus_assoc. +f_equal. +rewrite plus_comm. +apply opp_plus. rewrite <- squared_norm. -fold (Rsqr (Hnorm (minus (minus u f) (minus u g)))). +rewrite <- Rsqr_def. f_equal. rewrite <- norm_opp. unfold norm; simpl; f_equal. @@ -710,10 +715,10 @@ now rewrite plus_opp_r plus_zero_l. apply Rsqr_incrst_0. 2: apply norm_ge_0. 2: left; apply cond_pos. -apply Rle_lt_trans with (-4 * (norm (minus u (scal (/ 2) (plus f g))))² + +apply Rle_lt_trans with (-4 * (norm (minus u (scal (/ 2) (@plus E f g))))² + 2 * (norm (minus u f))² + 2 * (norm (minus u g))²). right. -apply Rplus_eq_reg_l with (4 * (norm (minus u (scal (/ 2) (plus f g))))²). +apply Rplus_eq_reg_l with (4 * (norm (minus u (scal (/ 2) (@plus E f g))))²). rewrite M; ring. apply Rle_lt_trans with (-4*Rsqr delta+2 * (norm (minus u f))² +2 * (norm (minus u g))²). apply Rplus_le_compat_r, Rplus_le_compat_r. @@ -723,10 +728,10 @@ apply Ropp_le_contravar. apply Rmult_le_pos; left; apply Rlt_0_2. apply Rsqr_incr_1; try assumption. assert (H2:Rbar_le (Glb_Rbar (fun r : R => exists w0 : E, phi w0 - /\ r = norm (minus u w0))) (norm (minus u (scal (/ 2) (plus f g))))). + /\ r = norm (@minus E u w0))) (norm (@minus E u (scal (/ 2) (@plus E f g))))). apply Glb_Rbar_correct. -exists (scal (/ 2) (plus f g)); split; try easy. -replace (scal (/ 2) (plus f g)) with +exists (scal (/ 2) (@plus E f g)); split; try easy. +replace (scal (/ 2) (@plus E f g)) with (plus (scal (/2) f) (scal (1-/2) g)). apply phi_convex; try assumption. split. @@ -865,7 +870,7 @@ intros H u v v' H0 H1 H2 H3. rewrite <- H3 in H1. pose (a := minus u v'). pose (b := minus u v). -pose (v'' := scal (1/2) (plus v v')). +pose (v'' := scal (1/2) (@plus E v v')). pose (d := norm (minus u v)). assert (E1 : plus (4*((norm (minus u v''))*(norm (minus u v'')))) ((norm (minus v v'))*(norm (minus v v'))) @@ -905,12 +910,12 @@ rewrite plus_comm. rewrite (plus_comm _ ((norm (minus v v') * norm (minus v v')))). apply Rplus_eq_compat_l. replace (plus (minus u v') (minus u v)) - with (minus ((scal 2) u) (plus v v')). + with (minus ((scal 2) u) (@plus E v v')). unfold minus. -assert (H42 : (4 *(norm (plus u (opp (scal (1 / 2) (plus v v')))) * - norm (plus u (opp (scal (1 / 2) (plus v v')))))) - =((2*(norm (plus u (opp (scal (1 / 2) (plus v v'))))))* - (2*(norm (plus u (opp (scal (1 / 2) (plus v v')))))))). +assert (H42 : (4 *(norm (@plus E u (opp (scal (1 / 2) (@plus E v v')))) * + norm (@plus E u (opp (scal (1 / 2) (@plus E v v')))))) + =((2*(norm (@plus E u (opp (scal (1 / 2) (@plus E v v'))))))* + (2*(norm (@plus E u (opp (scal (1 / 2) (@plus E v v')))))))). ring. rewrite H42. replace 2 with (Rabs 2) at 1. @@ -954,12 +959,9 @@ rewrite plus_comm. rewrite opp_opp. rewrite (plus_comm (opp u) v). rewrite plus_assoc. -rewrite <- (plus_assoc u (opp v') v). -rewrite (plus_comm u (plus (opp v') v)). -rewrite <- (plus_assoc (plus (opp v') v) u (opp u)). -rewrite plus_opp_r. -rewrite plus_zero_r. -reflexivity. +rewrite (plus_comm _ (opp u)). +repeat rewrite plus_assoc. +now rewrite plus_opp_l plus_zero_l. assert (Hmin : norm (minus v v') * norm (minus v v') <= 0). replace 0 with (plus (4*(d*d)) (-4*(d*d))). replace (norm (minus v v') * norm (minus v v')) with @@ -968,8 +970,8 @@ replace (norm (minus v v') * norm (minus v v')) with unfold minus. apply Rplus_le_compat_l. replace (-4 * (d * d)) with (opp (4*(d*d))). -assert (Has: opp (4 * (norm (plus u (opp v'')) * norm (plus u (opp v'')))) - = (-((4 * (norm (plus u (opp v'')) * norm (plus u (opp v''))))))). +assert (Has: opp (4 * (norm (@plus E u (opp v'')) * norm (@plus E u (opp v'')))) + = (-((4 * (norm (@plus E u (opp v'')) * norm (@plus E u (opp v''))))))). reflexivity. rewrite Has. clear Has. @@ -1001,7 +1003,7 @@ assert (Hf :Rbar_le x (norm (minus u v''))). apply Hp. unfold v''. unfold convex in phi_convex. -replace (scal (1 / 2) (plus v v')) +replace (scal (1 / 2) (@plus E v v')) with (plus (scal (1/2) v) (scal (1/2) v')). replace (1 / 2) with (1 - (1 / 2)) at 2 by field. @@ -1123,7 +1125,7 @@ intros. unfold minus at 1. rewrite scal_distr_r. rewrite scal_one. rewrite scal_opp_l. rewrite opp_plus. rewrite (opp_plus v (opp (scal t v))). rewrite opp_opp. rewrite plus_assoc. rewrite plus_assoc. -rewrite (plus_comm (plus u (opp (scal t w))) (opp v)). +rewrite (plus_comm (@plus E u (opp (scal t w))) (opp v)). rewrite plus_assoc. unfold minus at 1. rewrite (plus_comm (opp v) u). unfold minus. rewrite <- plus_assoc. rewrite <- scal_opp_l. @@ -1227,7 +1229,7 @@ Lemma charac_ortho_projection_convex_aux1_r : forall u v w :E, phi v -> <= ((norm (minus u w) * norm (minus u w)))). intros. assert (minus u w = plus (minus u v) (minus v w)). -unfold minus. rewrite (plus_comm u (opp v)). +unfold minus. rewrite (@plus_comm E u (opp v)). rewrite plus_assoc_gen. rewrite plus_opp_l. rewrite plus_zero_l. reflexivity. rewrite H2. @@ -1332,7 +1334,7 @@ intros; split. intros. assert (forall w, minus u w = plus (minus u v) (minus v w)). - unfold minus. symmetry. rewrite (plus_comm u (opp v)). + unfold minus. symmetry. rewrite (@plus_comm E u (opp v)). rewrite plus_assoc_gen. rewrite plus_opp_l. rewrite plus_zero_l. reflexivity. assert @@ -1433,7 +1435,7 @@ apply mod2. trivial. apply unique_existence1; split. apply ortho_projection_convex'; try easy. -exists zero. +exists (@zero E). apply (compatible_m_zero phi phi_mod). intros v v' Hv Hv'. apply (ortho_projection_convex_unique phi phi_convex phi_compl u); intuition. @@ -1489,7 +1491,7 @@ rewrite inner_scal_l. rewrite (H1 _ Hy); ring. Qed. -Lemma trivial_orth_compl : forall u : E, ((forall v : E, inner u v = 0) <-> u = zero). +Lemma trivial_orth_compl : forall u : E, ((forall v : E, inner u v = 0) <-> u = @zero E). intros; split. intro H0. assert (inner u u = 0). apply H0; trivial. apply PreHilbert.ax3 in H. trivial. @@ -1497,7 +1499,7 @@ intros. rewrite H. rewrite inner_zero_l; reflexivity. Qed. Lemma trivial_orth_compl' : forall (phi : E -> Prop) (u : E), - closed phi -> phi u -> ((forall v : E, phi v -> inner u v = 0) <-> u = zero). + closed phi -> phi u -> ((forall v : E, phi v -> inner u v = 0) <-> u = @zero E). intros; split. intro H1. assert (inner u u = 0). apply H1; trivial. apply PreHilbert.ax3 in H2. trivial. @@ -1524,7 +1526,7 @@ split; intros. + assert (forall w:E, phi w -> inner (minus u v) (minus w v) <= 0). apply charac_ortho_projection_subspace1; intuition. - pose (w' := plus w v). + pose (w' := @plus E w v). assert (inner (minus u v) w <= 0). assert (Hm1 : w = minus w' v). unfold minus. @@ -1670,7 +1672,7 @@ Qed. Lemma direct_sum_with_orth_compl_charac2: forall u v, phi v -> norm (minus u v) = Glb_Rbar (fun r => exists w:E, phi w /\ r = norm (minus u w)) -> - (orth_compl u <-> v = zero). + (orth_compl u <-> v = @zero E). split; intros. + unfold orth_compl in H1. assert @@ -1688,7 +1690,7 @@ split; intros. intros. apply direct_sum_with_orth_compl_decomp in H0. assert - (forall x : E, phi x -> orth_compl x -> x = zero). + (forall x : E, phi x -> orth_compl x -> x = @zero E). apply direct_sumable_with_orth_compl. unfold orth_compl in H0. assert (inner (minus u zero) (minus w zero) = 0). unfold minus. rewrite opp_zero. diff --git a/coq/CertRL/LM/lax_milgram.v b/rocq/CertRL/LM/lax_milgram.v similarity index 98% rename from coq/CertRL/LM/lax_milgram.v rename to rocq/CertRL/LM/lax_milgram.v index 1d860f94..83e54f31 100644 --- a/coq/CertRL/LM/lax_milgram.v +++ b/rocq/CertRL/LM/lax_milgram.v @@ -206,7 +206,7 @@ Theorem Riesz_Frechet'_zero_phi : forall (f:topo_dual E), phi v -> f v = inner u v. Proof. intros f H. -exists zero. +exists (@zero E). split. destruct m_C as ((CG1,(z,CG2)),CS). unfold compatible_m in m_C. @@ -242,9 +242,9 @@ apply C1; trivial. apply C1'; trivial. exists zero. split. -apply compatible_m_zero. +apply (compatible_m_zero phi0). apply C. -apply compatible_m_zero. +apply (compatible_m_zero phi0'). apply C'. intros x l. split. @@ -316,7 +316,7 @@ assert (forall u, exists! v:E, = Glb_Rbar (fun r => exists w:E, PHI w /\ r = norm (minus u w))). intros u; apply: ortho_projection_convex. -exists zero. +exists (@zero E). split. (*apply ker_nnker_equiv.*) apply: compatible_m_zero. @@ -772,7 +772,7 @@ exfalso. apply H. exists v. split; trivial. -assert (u = zero). +assert (u = @zero E). rewrite <- H0. assert (exists! (u:E), phi u /\ forall (v:E), phi v -> f v = inner u v). @@ -803,7 +803,7 @@ trivial. apply H2'. rewrite H2. rewrite norm_zero. -assert (forall u, u <> zero -> phi u -> norm (f u) <= 0 * norm u). +assert (forall u, u <> @zero E -> phi u -> norm (f u) <= 0 * norm u). intros u0 H3 H3'. unfold f_phi_neq_zero. rewrite H1. @@ -861,7 +861,7 @@ apply H0. apply (iota_elim _ _ ) in H0. unfold tau. rewrite H0. -assert (u <> zero). +assert (u <> @zero E). rewrite Heq. destruct H1 as (H11,H12). assert (f v = inner u1 v). @@ -1720,7 +1720,7 @@ destruct H as (H,H2). rewrite <- plus_assoc in H. destruct H as (Pu,H). symmetry in H. -assert (Hu0 : plus u zero = u). +assert (Hu0 : @plus E u zero = u). rewrite plus_zero_r. reflexivity. rewrite <- Hu0 in H at 1. apply plus_reg_l in H. @@ -1812,7 +1812,7 @@ rewrite <- plus_assoc. rewrite <- plus_assoc. rewrite plus_comm. rewrite (plus_comm (opp v') _). -assert (Hp1 : forall a b, plus (plus a b) v +assert (Hp1 : forall a b, plus (@plus E a b) v = plus a (plus b v)). intros. rewrite plus_assoc. @@ -1820,7 +1820,7 @@ reflexivity. rewrite Hp1. rewrite Hp1. rewrite Hp1. -rewrite (plus_assoc v (opp v') _). +rewrite (@plus_assoc E v (opp v') _). assert (Hp2 : forall a b, plus a (plus b (plus (opp v') v)) = plus (plus a b) (plus (opp v') v)). @@ -1830,30 +1830,30 @@ rewrite Hp2. rewrite Hp2. rewrite plus_comm. rewrite (plus_comm (opp v') v). -assert (Hpaux : ((plus (opp (scal r (Tau (A v)))) +assert (Hpaux : ((@plus E (opp (scal r (Tau (A v)))) (plus (scal r (Tau f)) (plus (opp (opp (scal r (Tau (A v'))))) (opp (scal r (Tau f))))))) - = (opp (scal r (Tau (A (plus v (opp v'))))))). + = (opp (scal r (Tau (A (@plus E v (opp v'))))))). rewrite plus_assoc. rewrite opp_opp. rewrite (plus_comm _ (opp (scal r (Tau f)))). -assert (Hg : forall a b c d : E, plus (plus a b) (plus c d) - = plus (plus a d) (plus b c)). +assert (Hg : forall a b c d : E, @plus E (@plus E a b) (@plus E c d) + = @plus E (@plus E a d) (@plus E b c)). intros. rewrite plus_assoc. rewrite plus_assoc. rewrite plus_comm. -rewrite (plus_comm a0 d). -rewrite <- (plus_assoc d a0 b). -rewrite <- (plus_assoc d _ _). +rewrite (@plus_comm E a0 d). +rewrite <- (@plus_assoc E d a0 b). +rewrite <- (@plus_assoc E d _ _). reflexivity. rewrite Hg. rewrite plus_opp_r. rewrite plus_zero_r. -replace (opp (scal r (Tau (A (plus v (opp v')))))) +replace (opp (scal r (Tau (A (@plus E v (opp v')))))) with - (scal r (Tau (A (opp (plus v (opp v')))))). + (scal r (Tau (A (opp (@plus E v (opp v')))))). replace ((opp (scal r (Tau (A v))))) with (scal r (Tau (A (opp v)))). @@ -1891,14 +1891,14 @@ rewrite Hsr. reflexivity. rewrite <- scal_opp_r. rewrite <- scal_opp_one. -replace ((scal (opp one) (Tau (A (plus v (opp v')))))) +replace ((scal (opp one) (Tau (A (@plus E v (opp v')))))) with - ((Tau (scal (opp one) (A (plus v (opp v')))))). + ((Tau (scal (opp one) (A (@plus E v (opp v')))))). assert (H : is_linear_mapping Tau). apply Riesz_Frechet_moreover2_phi; trivial. destruct H as (Hl1,Hl2). -replace (Tau (A (opp (plus v (opp v'))))) - with (Tau (scal (opp one) (A (plus v (opp v'))))). +replace (Tau (A (opp (@plus E v (opp v'))))) + with (Tau (scal (opp one) (A (@plus E v (opp v'))))). reflexivity. clear Hp1 Hp2 Hg. assert (Hlb : forall l y, scal l (A y) @@ -2353,7 +2353,7 @@ intros u' Hu'. symmetry. apply (proj2 Hu). trivial. -exists zero. +exists (@zero E). split. unfold is_sol_linear_pb_phi. split. @@ -2363,12 +2363,12 @@ intros v Hv. apply Is_only_zero_set_correct1_phi with E phi v in i. rewrite i. destruct Hba as ((Hba1,(Hba2,Hba3)),Hba4). -assert (a zero zero = a (scal zero zero) zero). +assert (a (@zero E) (@zero E) = a (scal zero zero) (@zero E)). rewrite scal_zero_l. reflexivity. rewrite H. -replace (a (scal zero zero) zero) - with (scal zero (a zero zero)). +replace (a (scal zero zero) (@zero E)) + with (scal zero (a (@zero E) (@zero E))). rewrite scal_zero_l. assert (is_linear_mapping f). apply f. diff --git a/coq/CertRL/LM/lax_milgram_cea.v b/rocq/CertRL/LM/lax_milgram_cea.v similarity index 94% rename from coq/CertRL/LM/lax_milgram_cea.v rename to rocq/CertRL/LM/lax_milgram_cea.v index 032ae62d..bb73b3d4 100644 --- a/coq/CertRL/LM/lax_milgram_cea.v +++ b/rocq/CertRL/LM/lax_milgram_cea.v @@ -138,10 +138,11 @@ destruct H0 as (H0,H7). destruct H7 as (H7,(H8,H9)). unfold minus. rewrite H9. -replace (a (plus u (opp uh)) u) with - (plus (a (plus u (opp uh)) u) 0). +transitivity (plus (a (@plus E u (opp uh)) u) 0); cycle 1. +{ unfold plus; simpl. + apply Rplus_0_r. +} f_equal. -now rewrite plus_zero_r. specialize (H1 uh). apply H1 in X1. specialize (H2 uh). @@ -165,10 +166,12 @@ rewrite H7 scal_opp_one. reflexivity. rewrite scal_opp_one. reflexivity. -rewrite plus_zero_r; reflexivity. -replace (M * norm (minus u vh) * norm (minus u uh)) with - (M * norm (minus u uh) * norm (minus u vh)) by ring. -assumption. +eapply Rle_trans; try eapply H4. +right. +simpl. +repeat rewrite Rmult_assoc. +f_equal. +now rewrite Rmult_comm. destruct Hca. intro Hk. rewrite Hk in H0. diff --git a/coq/CertRL/LM/linear_map.v b/rocq/CertRL/LM/linear_map.v similarity index 88% rename from coq/CertRL/LM/linear_map.v rename to rocq/CertRL/LM/linear_map.v index 16d90a67..94b708f3 100644 --- a/coq/CertRL/LM/linear_map.v +++ b/rocq/CertRL/LM/linear_map.v @@ -66,15 +66,21 @@ Lemma fct_plus_opp_r: forall f:E->F, fct_plus f (fct_opp f) = fct_zero. Proof. intros f. apply functional_extensionality. -intros x; apply plus_opp_r. +intros x; apply (@plus_opp_r F). Qed. +Definition fct_AbelianMonoid_mixin := + AbelianMonoid.Mixin (E->F) fct_plus fct_zero fct_plus_comm + fct_plus_assoc fct_plus_zero_r. + +Canonical fct_AbelianMonoid := + AbelianMonoid.Pack (E->F) (fct_AbelianMonoid_mixin) (E->F). + Definition fct_AbelianGroup_mixin := - AbelianGroup.Mixin (E->F) fct_plus fct_opp fct_zero fct_plus_comm - fct_plus_assoc fct_plus_zero_r fct_plus_opp_r. + AbelianGroup.Mixin fct_AbelianMonoid fct_opp fct_plus_opp_r. Canonical fct_AbelianGroup := - AbelianGroup.Pack (E->F) (fct_AbelianGroup_mixin) (E->F). + AbelianGroup.Pack (E->F) (AbelianGroup.Class _ (fct_AbelianMonoid_mixin) fct_AbelianGroup_mixin) (E->F). Lemma fct_scal_assoc: forall x y (u:E->F), fct_scal x (fct_scal y u) = fct_scal (mult x y) u. @@ -138,7 +144,10 @@ split. unfold plus at 1 4 5; unfold opp; simpl. unfold fct_plus, fct_opp. rewrite Hf1, Hg1. - rewrite opp_plus. + transitivity (plus (plus (f x) (f y)) (plus (opp (g x)) (opp (g y)))). + { f_equal. + apply (opp_plus (g x) (g y)). + } repeat rewrite <- plus_assoc. apply f_equal. repeat rewrite plus_assoc. @@ -193,7 +202,7 @@ Proof. intros f (H1,H2); split. intros x y; unfold opp; simpl; unfold fct_opp. rewrite H1. -apply opp_plus. +apply (opp_plus (f x) (f y)). intros x l; unfold opp; simpl; unfold fct_opp. rewrite H2. now rewrite <- scal_opp_r. @@ -254,13 +263,23 @@ intros x y l;unfold plus; unfold opp; simpl;unfold fct_plus, fct_opp;unfold plus unfold fct_plus, fct_opp;rewrite Hf2,Hg2;rewrite <- scal_opp_r;now rewrite scal_distr_l. split. intros x y z; unfold plus at 1 4 5; unfold opp;simpl; unfold fct_plus, fct_opp; - unfold plus at 1 5 6;unfold opp;simpl;unfold fct_plus, fct_opp;rewrite Hf3,Hg3; - rewrite opp_plus;rewrite plus_assoc;rewrite plus_assoc;apply f_equal2;trivial. - rewrite <- plus_assoc;rewrite (plus_comm (f y z) (opp (g x z)));now rewrite plus_assoc. -intros x y z; unfold plus at 1 4 5; unfold opp;simpl; unfold fct_plus, fct_opp. - unfold plus at 1 4 5;unfold opp;simpl;unfold fct_plus, fct_opp. rewrite Hf4,Hg4; - rewrite opp_plus;rewrite plus_assoc;rewrite plus_assoc;apply f_equal2;trivial. - rewrite <- plus_assoc;rewrite (plus_comm (f x z) (opp (g x y)));now rewrite plus_assoc. + unfold plus at 1 5 6;unfold opp;simpl;unfold fct_plus, fct_opp;rewrite Hf3,Hg3. +repeat rewrite <- plus_assoc. +f_equal. +rewrite (plus_comm (opp (g x z))). +repeat rewrite <- plus_assoc. +f_equal. +rewrite (plus_comm (opp _)). +apply (opp_plus (g x z) (g y z)). +intros x y z; unfold plus at 1 4 5; unfold opp;simpl; unfold fct_plus, fct_opp; + unfold plus at 1 4 5;unfold opp;simpl;unfold fct_plus, fct_opp. rewrite Hf4,Hg4. +repeat rewrite <- plus_assoc. +f_equal. +rewrite (plus_comm (opp (g x y))). +repeat rewrite <- plus_assoc. +f_equal. +rewrite (plus_comm (opp _)). +apply (opp_plus (g x y) (g x z)). exists zero;split. unfold zero;intros;simpl;unfold fct_zero; simpl;unfold zero;simpl;unfold fct_zero; now rewrite scal_zero_r. diff --git a/coq/CertRL/LM/logic_tricks.v b/rocq/CertRL/LM/logic_tricks.v similarity index 96% rename from coq/CertRL/LM/logic_tricks.v rename to rocq/CertRL/LM/logic_tricks.v index 5190454a..d350b6cf 100644 --- a/coq/CertRL/LM/logic_tricks.v +++ b/rocq/CertRL/LM/logic_tricks.v @@ -118,17 +118,16 @@ intros k Hk. replace k with 0%nat. apply H. intros m Hm; contradict Hm. -apply lt_n_0. +auto with arith. generalize Hk; case k; trivial. intros m Hm; contradict Hm. -apply le_not_lt. -now auto with arith. +now auto with arith. intros k Hk. apply H. intros m Hm. apply IHn. -apply lt_le_trans with (1:=Hm). -now apply gt_S_le. +apply Nat.lt_le_trans with (1:=Hm). +auto with arith. apply H0. apply le_n. Qed. diff --git a/coq/CertRL/README.md b/rocq/CertRL/README.md similarity index 100% rename from coq/CertRL/README.md rename to rocq/CertRL/README.md diff --git a/coq/CertRL/cond_expt.v b/rocq/CertRL/cond_expt.v similarity index 100% rename from coq/CertRL/cond_expt.v rename to rocq/CertRL/cond_expt.v diff --git a/coq/CertRL/finite_time.v b/rocq/CertRL/finite_time.v similarity index 100% rename from coq/CertRL/finite_time.v rename to rocq/CertRL/finite_time.v diff --git a/coq/CertRL/mdp.v b/rocq/CertRL/mdp.v similarity index 99% rename from coq/CertRL/mdp.v rename to rocq/CertRL/mdp.v index 5c721f05..eede0e86 100644 --- a/coq/CertRL/mdp.v +++ b/rocq/CertRL/mdp.v @@ -61,12 +61,12 @@ Record MDP := mkMDP { act : forall s: state, Type; (** The state space has decidable equality.*) - st_eqdec :> EqDec state eq; - act_eqdec :> (forall s, EqDec (act s) eq); + st_eqdec ::> EqDec state eq; + act_eqdec ::> (forall s, EqDec (act s) eq); (** The state and action spaces are finite. *) - fs :> FiniteType (state) ; - fa :> forall s, FiniteType (act s); + fs ::> FiniteType (state) ; + fa ::> forall s, FiniteType (act s); (** The state space and the fibered action spaces are nonempty. *) ne : NonEmpty (state) ; @@ -96,8 +96,6 @@ Proof. eapply FiniteType_fun_dep ; eauto. - apply fs. - apply fa. - Unshelve. - apply st_eqdec. Qed. Global Instance act_finite (M : MDP) : FiniteType (sigT M.(act)) @@ -407,12 +405,19 @@ Proof. lra. Qed. + +Definition Rfct_AbelianMonoid_mixin := + AbelianMonoid.Mixin (Rfct A) Rfct_plus Rfct_zero Rfct_plus_comm + Rfct_plus_assoc Rfct_plus_zero_r. + +Canonical Rfct_AbelianMonoid := + AbelianMonoid.Pack (Rfct A) (Rfct_AbelianMonoid_mixin) (Rfct A). + Definition Rfct_AbelianGroup_mixin := - AbelianGroup.Mixin (Rfct A) Rfct_plus Rfct_opp Rfct_zero Rfct_plus_comm - Rfct_plus_assoc Rfct_plus_zero_r Rfct_plus_opp_r. + AbelianGroup.Mixin Rfct_AbelianMonoid Rfct_opp Rfct_plus_opp_r. Canonical Rfct_AbelianGroup := - AbelianGroup.Pack (Rfct A) (Rfct_AbelianGroup_mixin) (Rfct A). + AbelianGroup.Pack (Rfct A) (AbelianGroup.Class _ (Rfct_AbelianMonoid_mixin) Rfct_AbelianGroup_mixin) (Rfct A). End Rfct_AbelianGroup. diff --git a/coq/CertRL/mdp_turtle.v b/rocq/CertRL/mdp_turtle.v similarity index 96% rename from coq/CertRL/mdp_turtle.v rename to rocq/CertRL/mdp_turtle.v index d296dd12..d8168f26 100644 --- a/coq/CertRL/mdp_turtle.v +++ b/rocq/CertRL/mdp_turtle.v @@ -33,14 +33,14 @@ Section turtle. Definition turtle_state max_x max_y := prod ({x:nat | x < max_x}%nat) ({y:nat | y < max_y}%nat). (* Convenience method for creating a state with known in-bounds constant co-ordinates *) - Definition make_turtle_state max_x max_y x y : if lt_dec x max_x - then if lt_dec y max_y + Definition make_turtle_state max_x max_y x y : if Compare_dec.lt_dec x max_x + then if Compare_dec.lt_dec y max_y then turtle_state max_x max_y else True else True. Proof. - destruct (lt_dec x max_x). - - destruct (lt_dec y max_y). + destruct (Compare_dec.lt_dec x max_x). + - destruct (Compare_dec.lt_dec y max_y). + apply pair. * exists x; trivial. * exists y; trivial. @@ -126,9 +126,9 @@ Section turtle. := (let '(x,y) := s in match a with | Up => if proj1_sig y == 0 then None else Some (x, y-1) - | Down => if lt_dec (y+1) max_y then Some (x, y+1) else None + | Down => if Compare_dec.lt_dec (y+1) max_y then Some (x, y+1) else None | Left => if proj1_sig x == 0 then None else Some (x-1, y) - | Right => if lt_dec (x+1) max_x then Some (x+1, y) else None + | Right => if Compare_dec.lt_dec (x+1) max_x then Some (x+1, y) else None end)%nat. Next Obligation. lia. @@ -253,7 +253,7 @@ End optimal_path. Section to_string. Section utils. - Definition newline := String (Ascii.ascii_of_N 10) EmptyString. + Definition newline := String (Ascii.ascii_of_N (Npos 10%positive)) EmptyString. Definition string_bracket (sstart send:string) (smiddle:string) := String.append sstart (String.append smiddle send). diff --git a/coq/CertRL/orderfun.v b/rocq/CertRL/orderfun.v similarity index 100% rename from coq/CertRL/orderfun.v rename to rocq/CertRL/orderfun.v diff --git a/coq/CertRL/pmf_monad.v b/rocq/CertRL/pmf_monad.v similarity index 100% rename from coq/CertRL/pmf_monad.v rename to rocq/CertRL/pmf_monad.v diff --git a/coq/CertRL/pmf_prob.v b/rocq/CertRL/pmf_prob.v similarity index 98% rename from coq/CertRL/pmf_prob.v rename to rocq/CertRL/pmf_prob.v index 3d67c646..6e9f1b10 100644 --- a/coq/CertRL/pmf_prob.v +++ b/rocq/CertRL/pmf_prob.v @@ -127,7 +127,7 @@ Section Pmf_PMF. rewrite Forall_forall; intros lmu. specialize (lmu _ inn). lia. - -- destruct (in_dec eq_nat_dec a l); trivial. + -- destruct (in_dec Peano_dec.eq_nat_dec a l); trivial. specialize (nin _ n). unfold nequiv_decb, negb, equiv_decb in non0. destruct (equiv_dec (coll a) 0); congruence. @@ -283,7 +283,7 @@ Section Pmf_PMF. (map (fun x : nonnegreal * A => nonneg (fst x)) (filter (fun x : nonnegreal * A => if equiv_dec a (snd x) then true else false) outcomes)) | None => 0 - end) (nodup eq_nat_dec (map countable_index (map snd outcomes))))). + end) (nodup Peano_dec.eq_nat_dec (map countable_index (map snd outcomes))))). - apply infinite_sum'_finite ; intros. + match_case; intros. @@ -389,7 +389,7 @@ Section pmf_prob. (exist (fun _ : pre_event A => True) (fun omega : A => pre_event_singleton c (rv_X omega)) (sa_preimage_singleton rv_X c))) - (nodup eq_nat_dec (map countable_index (map snd pmf.(outcomes))))) + (nodup Peano_dec.eq_nat_dec (map countable_index (map snd pmf.(outcomes))))) ; intros HH. cut_to HH. - rewrite (infinite_sum'_unique i HH). diff --git a/coq/CertRL/qvalues.v b/rocq/CertRL/qvalues.v similarity index 100% rename from coq/CertRL/qvalues.v rename to rocq/CertRL/qvalues.v diff --git a/coq/CertRL/refs.md b/rocq/CertRL/refs.md similarity index 100% rename from coq/CertRL/refs.md rename to rocq/CertRL/refs.md diff --git a/coq/ProbTheory/Almost.v b/rocq/ProbTheory/Almost.v similarity index 99% rename from coq/ProbTheory/Almost.v rename to rocq/ProbTheory/Almost.v index b6b73a82..8723f05a 100644 --- a/coq/ProbTheory/Almost.v +++ b/rocq/ProbTheory/Almost.v @@ -481,7 +481,7 @@ Section almostR2_part. apply all_almost; intros ω Pω. exists N; trivial. - intros. - apply le_dec. + apply Compare_dec.le_dec. - trivial. Qed. diff --git a/coq/ProbTheory/BorelSigmaAlgebra.v b/rocq/ProbTheory/BorelSigmaAlgebra.v similarity index 100% rename from coq/ProbTheory/BorelSigmaAlgebra.v rename to rocq/ProbTheory/BorelSigmaAlgebra.v diff --git a/coq/ProbTheory/ConditionalExpectation.v b/rocq/ProbTheory/ConditionalExpectation.v similarity index 100% rename from coq/ProbTheory/ConditionalExpectation.v rename to rocq/ProbTheory/ConditionalExpectation.v diff --git a/coq/ProbTheory/DiscreteProbSpace.v b/rocq/ProbTheory/DiscreteProbSpace.v similarity index 99% rename from coq/ProbTheory/DiscreteProbSpace.v rename to rocq/ProbTheory/DiscreteProbSpace.v index 820341c9..77abce95 100644 --- a/coq/ProbTheory/DiscreteProbSpace.v +++ b/rocq/ProbTheory/DiscreteProbSpace.v @@ -1209,7 +1209,7 @@ Section countable_products. intros. apply Rge_minus. unfold double_sum. - destruct (lt_dec n2 n1). + destruct (Compare_dec.lt_dec n2 n1). - rewrite (sum_f_R0_split _ n1 n2); trivial. apply Rge_trans with (r2 := sum_f_R0 (fun i : nat => sum_f_R0 (fun j : nat => f i j) m1) n2). + rewrite <- Rplus_0_r. @@ -1220,7 +1220,7 @@ Section countable_products. intros; apply H. + apply Rle_ge, sum_f_R0_le. intros. - destruct (lt_dec m2 m1). + destruct (Compare_dec.lt_dec m2 m1). * rewrite (sum_f_R0_split _ m1 m2); trivial. rewrite <- Rplus_0_r at 1. apply Rplus_le_compat_l. @@ -1230,7 +1230,7 @@ Section countable_products. - assert (n1 = n2) by lia; subst. apply Rle_ge, sum_f_R0_le. intros. - destruct (lt_dec m2 m1). + destruct (Compare_dec.lt_dec m2 m1). + rewrite (sum_f_R0_split _ m1 m2); trivial. apply Rplus_le_pos_l. apply sum_f_R0_nneg. @@ -1284,7 +1284,7 @@ Section countable_products. Rabs ((double_sum f m n) - (double_sum f n n)) = Rabs ((double_sum f m m) - (double_sum f n n))). { - destruct (ge_dec m n)%nat. + destruct (Compare_dec.ge_dec m n)%nat. rewrite Rabs_right, Rabs_right, Rabs_right; try lra. apply double_sum_ge; trivial; lia. apply double_sum_ge; trivial; lia. @@ -1385,7 +1385,7 @@ Section countable_products. rewrite <- Lim_seq_const. apply Lim_seq_le_loc. exists n; intros. - destruct (lt_dec n n0). + destruct (Compare_dec.lt_dec n n0). generalize (sum_f_R0_split f n0 n); intros. rewrite H1; try lia. apply Rplus_le_pos_l. diff --git a/coq/ProbTheory/Dynkin.v b/rocq/ProbTheory/Dynkin.v similarity index 99% rename from coq/ProbTheory/Dynkin.v rename to rocq/ProbTheory/Dynkin.v index fd42f9d0..f819f2b3 100644 --- a/coq/ProbTheory/Dynkin.v +++ b/rocq/ProbTheory/Dynkin.v @@ -24,7 +24,7 @@ Section dynkin. Class Lambda_system (c:pre_event T -> Prop) := mk_lambda_system { lambda_Ω : c pre_Ω - ; lambda_proper :> Proper (pre_event_equiv ==> iff) c + ; lambda_proper ::> Proper (pre_event_equiv ==> iff) c ; lambda_complement {a} : c a -> c (pre_event_complement a) ; lambda_disjoint_countable_union (an : nat -> pre_event T) : (forall x, c (an x)) -> @@ -654,7 +654,7 @@ Section monotone_class_def. Class monotone_class (M : pre_event E -> Prop) := mk_monotone_class { monotone_Ω : M pre_Ω - ; monotone_proper :> Proper (pre_event_equiv ==> iff) M + ; monotone_proper ::> Proper (pre_event_equiv ==> iff) M ; monotone_diff a b : M a -> M b -> pre_event_sub a b -> diff --git a/coq/ProbTheory/Event.v b/rocq/ProbTheory/Event.v similarity index 99% rename from coq/ProbTheory/Event.v rename to rocq/ProbTheory/Event.v index a2dac2b3..01b05fba 100644 --- a/coq/ProbTheory/Event.v +++ b/rocq/ProbTheory/Event.v @@ -2151,7 +2151,7 @@ Section event. unfold collection_take. split. - intros na. - destruct (lt_dec a n). + destruct (Compare_dec.lt_dec a n). + split; trivial. destruct (map_nth_in_exists En (seq 0 n) event_none a). * now rewrite seq_length. @@ -2357,7 +2357,7 @@ Section vec. Definition pre_bounded_inter_of_pre_collection_unbound {T} {n} (collection: forall i (pf:(i match lt_dec i n with + (fun i => match Compare_dec.lt_dec i n with | left pf => collection i pf | right _ => pre_Ω end). @@ -2396,7 +2396,7 @@ Section vec. Definition bounded_inter_of_collection_unbound {T} {n} {σ: SigmaAlgebra T} (collection: forall i (pf:(i match lt_dec i n with + (fun i => match Compare_dec.lt_dec i n with | left pf => collection i pf | right _ => Ω end). @@ -2719,7 +2719,7 @@ Section pre_make_disjoint. Definition make_pre_collection_disjoint (coll:nat->pre_event T) : nat -> pre_event T := fun x => pre_event_diff (coll x) ((pre_union_of_collection (fun y => - if lt_dec y x + if Compare_dec.lt_dec y x then coll y else pre_event_none))). @@ -2755,13 +2755,13 @@ Section pre_make_disjoint. intros y ylt cy. apply H2. exists y. - destruct (lt_dec y x); intuition. + destruct (Compare_dec.lt_dec y x); intuition. - intros [ce fce]. unfold make_pre_collection_disjoint. split; trivial. unfold pre_union_of_collection. intros [n Hn]. - destruct (lt_dec n x); trivial. + destruct (Compare_dec.lt_dec n x); trivial. eapply fce; eauto. Qed. @@ -2773,7 +2773,7 @@ Section pre_make_disjoint. apply make_pre_collection_disjoint_in in e2. destruct e1 as [H11 H12]. destruct e2 as [H21 H22]. - destruct (not_eq _ _ xyneq) as [xlt|ylt]. + destruct (Compare_dec.not_eq _ _ xyneq) as [xlt|ylt]. - eapply H22; eauto. - eapply H12; eauto. Qed. @@ -2797,7 +2797,7 @@ Section pre_make_disjoint. split; trivial. unfold pre_union_of_collection. intros [nn Hnn]. - destruct (lt_dec nn m); [ | tauto]. + destruct (Compare_dec.lt_dec nn m); [ | tauto]. specialize (H0 _ Hnn). lia. - apply make_pre_collection_disjoint_in in Hn. @@ -2870,7 +2870,7 @@ Section more_pre_props. unfold collection_take. split. - intros na. - destruct (lt_dec a n). + destruct (Compare_dec.lt_dec a n). + split; trivial. destruct (map_nth_in_exists En (seq 0 n) pre_event_none a). * now rewrite seq_length. @@ -2923,8 +2923,8 @@ Section more_pre_props. Qed. Lemma pre_union_of_collection_lt_S {A} (E:nat->pre_event A) n : - pre_event_equiv (pre_union_of_collection (fun y : nat => if lt_dec y (S n) then E y else pre_event_none)) - (pre_event_union (E n) (pre_union_of_collection (fun y : nat => if lt_dec y n then E y else pre_event_none))). + pre_event_equiv (pre_union_of_collection (fun y : nat => if Compare_dec.lt_dec y (S n) then E y else pre_event_none)) + (pre_event_union (E n) (pre_union_of_collection (fun y : nat => if Compare_dec.lt_dec y n then E y else pre_event_none))). Proof. intros ?; split. - intros [? HH]. diff --git a/coq/ProbTheory/Expectation.v b/rocq/ProbTheory/Expectation.v similarity index 99% rename from coq/ProbTheory/Expectation.v rename to rocq/ProbTheory/Expectation.v index 0974f7bc..c2bfb0dc 100644 --- a/coq/ProbTheory/Expectation.v +++ b/rocq/ProbTheory/Expectation.v @@ -1,4 +1,4 @@ -Require Import Reals. +Require Import ZArith Reals. Require Import Lra Lia. Require Import List Permutation. @@ -1074,7 +1074,7 @@ Section Expectation_sec. * invcs H1. rewrite H4 in H3. congruence. - + destruct (gt_dec n 0). + + destruct (Compare_dec.gt_dec n 0). * generalize (find_none _ _ H2); intros. specialize (H3 0). rewrite <- in_rev in H3. @@ -1095,7 +1095,7 @@ Section Expectation_sec. generalize (pow_exp_gt 2 n); intros. cut_to H4. replace (0%nat) with (n*0)%nat at 1 by lia. - apply mult_lt_compat_l; lia. + apply Nat.mul_lt_mono_pos_l; lia. lia. * assert (n = 0)%nat by lia. invcs H3. @@ -1142,13 +1142,13 @@ Section Expectation_sec. clear n0. assert (pos1:(n * 2 ^ n > 0)%nat). { - apply NPeano.Nat.mul_pos_pos. + apply Nat.mul_pos_pos. - destruct n; try lia. simpl in Xlt. specialize (posX omega). now apply Rbar_le_not_lt in posX. - simpl. - apply NPeano.Nat.Private_NZPow.pow_pos_nonneg + apply Nat.Private_NZPow.pow_pos_nonneg ; lia. } match_case; intros. @@ -1191,7 +1191,7 @@ Section Expectation_sec. simpl in HH. rewrite app_ass in HH. rewrite app_nth2 in HH by lia. - rewrite NPeano.Nat.sub_diag in HH. + rewrite Nat.sub_diag in HH. simpl in HH. subst. split; intros. @@ -1210,7 +1210,7 @@ Section Expectation_sec. rewrite last_app in eqq0 by congruence. simpl in eqq0. rewrite <- eqq0. - rewrite NPeano.Nat.sub_1_r. + rewrite Nat.sub_1_r. rewrite Nat.succ_pred_pos by trivial. rewrite mult_INR. simpl. @@ -1234,7 +1234,7 @@ Section Expectation_sec. rewrite eqq4 in Fl1. invcs Fl1. match_destr_in H1. - rewrite NPeano.Nat.add_1_r in n0. + rewrite Nat.add_1_r in n0. rewrite Rbar_div_Rdiv. now apply Rbar_not_le_lt in n0. + rewrite app_length; simpl. @@ -1251,8 +1251,8 @@ Section Expectation_sec. rewrite <- Rbar_div_Rdiv in r0. rewrite Rbar_le_div_l in r0 ; [| apply pow2_pos]. { - destruct (lt_eq_lt_dec (length d) k) as [[lt1|]|lt1]; trivial - ; elimtype False. + destruct (Compare_dec.lt_eq_lt_dec (length d) k) as [[lt1|]|lt1]; trivial + ; exfalso. - generalize (f_equal (fun x => nth k x 0)%nat); intros HH. specialize (HH _ _ eqq2). { @@ -1363,14 +1363,7 @@ Section Expectation_sec. intros posX. intros omega k. intros klt. - assert (pos1:(n * 2 ^ n > 0)%nat). - { - apply NPeano.Nat.mul_pos_pos. - - destruct n; try lia. - - simpl. - apply NPeano.Nat.Private_NZPow.pow_pos_nonneg - ; lia. - } + assert (pos1:(n * 2 ^ n > 0)%nat) by lia. unfold simple_approx. split; intros HH. - match_destr_in HH. @@ -1414,7 +1407,7 @@ Section Expectation_sec. rewrite last_app in eqq0 by congruence. simpl in eqq0. subst. - rewrite NPeano.Nat.sub_1_r. + rewrite Nat.sub_1_r. rewrite Nat.succ_pred_pos by trivial. rewrite mult_INR. unfold Rdiv. @@ -1453,7 +1446,7 @@ Section Expectation_sec. rewrite app_length in HH. replace ((S (length (rev l2)) - (length (rev l2) + length [length (rev l2)])))%nat with 0%nat in HH. * rewrite rev_nth in HH by (simpl; lia). - rewrite plus_0_l in HH. + rewrite Nat.add_0_l in HH. rewrite HH. rewrite Forall_forall in Fl1. specialize (Fl1 (nth (length (n1 :: l1) - 1) (n1 :: l1) 0%nat)). @@ -1537,8 +1530,8 @@ Section Expectation_sec. subst. apply Rmult_eq_compat_r. f_equal. - destruct (lt_eq_lt_dec (length (rev l2)) k) as [[lt1|]|lt1]; trivial - ; elimtype False. + destruct (Compare_dec.lt_eq_lt_dec (length (rev l2)) k) as [[lt1|]|lt1]; trivial + ; exfalso. - generalize (f_equal (fun x => nth k x ((fun x : nat => INR x / 2 ^ n) 0%nat))); intros HH. specialize (HH _ _ eqq2). { @@ -1991,7 +1984,7 @@ Section Expectation_sec. apply Rinv_0_lt_compat. apply cond_pos. ++ apply Rle_pow; [lra |]. - apply Max.le_max_l. + apply Nat.le_max_l. * rewrite Rinv_involutive. reflexivity. apply Rgt_not_eq. @@ -2006,7 +1999,7 @@ Section Expectation_sec. specialize (posX ω). now rewrite <- isfin in posX; simpl in posX. * apply le_INR. - apply Max.le_max_r. + apply Nat.le_max_r. Qed. Lemma simple_approx_lim_seq (X:Ts -> Rbar) (posX : Rbar_NonnegativeFunction X) : diff --git a/coq/ProbTheory/FunctionsToReal.v b/rocq/ProbTheory/FunctionsToReal.v similarity index 97% rename from coq/ProbTheory/FunctionsToReal.v rename to rocq/ProbTheory/FunctionsToReal.v index 8cb665fd..e4ae64ef 100644 --- a/coq/ProbTheory/FunctionsToReal.v +++ b/rocq/ProbTheory/FunctionsToReal.v @@ -81,7 +81,7 @@ Section defs. (fun omega => (rv_X1 omega) + (rv_X2 omega)). Definition rvsum (Xn : nat -> Ts -> R) (n : nat) := - (fun omega => sum_n (fun n0 => Xn n0 omega) n). + (fun omega => @sum_n R_AbelianGroup (fun n0 => Xn n0 omega) n). Definition rvscale (c:R) (rv_X : Ts -> R) := fun omega => c * (rv_X omega). @@ -845,7 +845,7 @@ Section defs. Qed. Lemma Rbar_Rabs_lim_sum_le0 (f : nat -> Ts -> R) (x : Ts) : - is_finite (Lim_seq (sum_n (fun n=> Rabs (f n x)))) -> + is_finite (Lim_seq (@sum_n R_AbelianGroup (fun n=> Rabs (f n x)))) -> Rbar_le (Rbar_abs (Lim_seq (fun n => (rvsum f) n x))) (Rbar_abs (Lim_seq (fun n => (rvsum (fun n0 => (rvabs (f n0))) n x)))). @@ -859,11 +859,11 @@ Section defs. apply ex_series_Lim_seq in H. apply ex_series_Lim_seq in H0. replace (Lim_seq - (fun n : nat => sum_n (fun n0 : nat => f n0 x) n)) + (fun n : nat => @sum_n R_AbelianGroup (fun n0 : nat => f n0 x) n)) with (Finite ( Series (fun n : nat => f n x))). replace (Lim_seq (fun n : nat => - sum_n (fun n0 : nat => Rabs (f n0 x)) n)) + @sum_n R_AbelianGroup (fun n0 : nat => Rabs (f n0 x)) n)) with (Finite (Series (fun n : nat => Rabs (f n x)))). simpl. apply Rge_le. @@ -875,7 +875,7 @@ Section defs. Qed. Lemma Rabs_lim_sum_le0 (f : nat -> Ts -> R) (x : Ts) : - is_finite (Lim_seq (sum_n (fun n=> Rabs (f n x)))) -> + is_finite (Lim_seq (@sum_n R_AbelianGroup (fun n=> Rabs (f n x)))) -> Rbar_le (Rbar_abs (Finite (real (Lim_seq (fun n => (rvsum f) n x))))) (Rbar_abs (Lim_seq (fun n => (rvsum (fun n0 => (rvabs (f n0))) n x)))). @@ -889,11 +889,11 @@ Section defs. apply ex_series_Lim_seq in H. apply ex_series_Lim_seq in H0. replace (Lim_seq - (fun n : nat => sum_n (fun n0 : nat => f n0 x) n)) + (fun n : nat => @sum_n R_AbelianGroup (fun n0 : nat => f n0 x) n)) with (Finite ( Series (fun n : nat => f n x))). replace (Lim_seq (fun n : nat => - sum_n (fun n0 : nat => Rabs (f n0 x)) n)) + @sum_n R_AbelianGroup (fun n0 : nat => Rabs (f n0 x)) n)) with (Finite (Series (fun n : nat => Rabs (f n x)))). simpl. apply Rge_le. @@ -915,11 +915,11 @@ Section defs. - rewrite <- H. apply Rbar_Rabs_lim_sum_le0. unfold rvsum, rvabs in H. - replace (Lim_seq (sum_n (fun n : nat => Rabs (f n x)))) + replace (Lim_seq (@sum_n R_AbelianGroup (fun n : nat => Rabs (f n x)))) with (Lim_seq (fun n : nat => - sum_n (fun n0 : nat => Rabs (f n0 x)) n)). + @sum_n R_AbelianGroup (fun n0 : nat => Rabs (f n0 x)) n)). now rewrite H. apply Lim_seq_ext. intros; trivial. @@ -949,11 +949,11 @@ Section defs. - rewrite <- H. apply Rabs_lim_sum_le0. unfold rvsum, rvabs in H. - replace (Lim_seq (sum_n (fun n : nat => Rabs (f n x)))) + replace (Lim_seq (@sum_n R_AbelianGroup (fun n : nat => Rabs (f n x)))) with (Lim_seq (fun n : nat => - sum_n (fun n0 : nat => Rabs (f n0 x)) n)). + @sum_n R_AbelianGroup (fun n0 : nat => Rabs (f n0 x)) n)). now rewrite H. apply Lim_seq_ext. intros; trivial. diff --git a/coq/ProbTheory/Gaussian.v b/rocq/ProbTheory/Gaussian.v similarity index 100% rename from coq/ProbTheory/Gaussian.v rename to rocq/ProbTheory/Gaussian.v diff --git a/coq/ProbTheory/Independence.v b/rocq/ProbTheory/Independence.v similarity index 100% rename from coq/ProbTheory/Independence.v rename to rocq/ProbTheory/Independence.v diff --git a/coq/ProbTheory/Martingale.v b/rocq/ProbTheory/Martingale.v similarity index 98% rename from coq/ProbTheory/Martingale.v rename to rocq/ProbTheory/Martingale.v index ada6417a..4e7b17c7 100644 --- a/coq/ProbTheory/Martingale.v +++ b/rocq/ProbTheory/Martingale.v @@ -500,7 +500,7 @@ Section martingale. FiniteExpectation prts (Y s) <= FiniteExpectation prts (Y t). Proof. intros s t sltt. - destruct (le_lt_or_eq _ _ sltt). + destruct (proj1 (Nat.lt_eq_cases _ _) sltt). - eapply is_sub_martingale_lt in mart; try eapply H. assert (rv1:RandomVariable dom borel_sa (FiniteConditionalExpectation prts (sub s) (Y t))). { @@ -546,7 +546,7 @@ Section martingale. cut (forall s t, (s <= t)%nat -> FiniteExpectation prts (Y s) = FiniteExpectation prts (Y t)). { intros. - destruct (NPeano.Nat.le_ge_cases s t). + destruct (Nat.le_ge_cases s t). - now apply H. - symmetry; now apply H. } @@ -876,7 +876,7 @@ Section martingale. end} with | None => right (fun x => x) - | Some a => le_dec a n + | Some a => Compare_dec.le_dec a n end). @@ -931,7 +931,7 @@ Section martingale. split; intros HH2. * invcs HH2. reflexivity. - * apply le_n_0_eq in HH2; congruence. + * apply Nat.le_0_r in HH2; congruence. + generalize (HH (S n)); intros HH1. generalize (HH n); intros HH2. apply sa_complement in HH2. @@ -954,7 +954,7 @@ Section martingale. unfold stopping_time_pre_event, const. apply sa_sigma_const. destruct c. - - destruct (le_dec n0 n); try tauto. + - destruct (Compare_dec.le_dec n0 n); try tauto. - tauto. Qed. @@ -1341,14 +1341,14 @@ Section martingale. intros ?. unfold pre_event_union, pre_event_complement. split; intros. - - destruct (le_dec n0 n1); eauto. + - destruct (Compare_dec.le_dec n0 n1); eauto. - destruct H; tauto. } rewrite eqq2; clear eqq2. apply sa_union. - apply sa_complement. apply sa_sigma_const. - destruct (le_dec n0 n1); tauto. + destruct (Compare_dec.le_dec n0 n1); tauto. - apply stop'. Qed. @@ -1724,13 +1724,14 @@ Section martingale. unfold rvsum. destruct k; simpl. - lra. - - generalize (@Hierarchy.sum_n_plus Hierarchy.R_AbelianGroup + - generalize (@Hierarchy.sum_n_plus + (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n0 : nat => H1 (S n0) a * (X (S n0) a + -1 * X n0 a)) (fun n0 : nat => H2 (S n0) a * (X (S n0) a + -1 * X n0 a)) k); intros eqq. unfold Hierarchy.plus in eqq; simpl in eqq. rewrite <- eqq. - apply (@Hierarchy.sum_n_ext Hierarchy.R_AbelianGroup); intros. + apply (@Hierarchy.sum_n_ext (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup)); intros. lra. Qed. @@ -1741,7 +1742,7 @@ Section martingale. unfold martingale_transform. destruct k; trivial. unfold rvsum; simpl. - apply (@Hierarchy.sum_n_ext Hierarchy.R_AbelianGroup); intros. + apply (@Hierarchy.sum_n_ext Hierarchy.R_AbelianMonoid); intros. rv_unfold. now rewrite eqq1, eqq2. Qed. @@ -1797,7 +1798,7 @@ Section martingale. * destruct HH as [??]. f_equal. apply antisymmetry - ; apply not_lt + ; apply Nat.le_ngt ; intros HH. -- eapply classic_min_of_some_first in H; eauto. -- specialize (H1 _ HH). @@ -1811,7 +1812,7 @@ Section martingale. - apply sa_inter. + apply adaptX. + apply sa_pre_countable_inter; intros. - destruct (lt_dec n0 n). + destruct (Compare_dec.lt_dec n0 n). * apply (sa_proper _ (fun x => ~ B (X n0 x))). -- intros ?; tauto. -- apply sa_complement. @@ -1872,7 +1873,7 @@ Section martingale. f_equal; lia. + apply sa_countable_union; intros. unfold IsStoppingTime, stopping_time_pre_event in IHn. - * destruct (lt_dec n0 a). + * destruct (Compare_dec.lt_dec n0 a). -- apply sa_inter. ++ apply sa_sigma_const. now left. @@ -1926,7 +1927,7 @@ Section martingale. + split; [| congruence]. intros [?[?[??]]]; congruence. - apply sa_countable_union; intros old. - destruct (le_dec old n). + destruct (Compare_dec.le_dec old n). + apply sa_inter. * eapply sa_proper; try eapply sa_all. firstorder. @@ -1972,7 +1973,7 @@ Section martingale. shelve. { intros ?. - rewrite plus_0_r. + rewrite Nat.add_0_r. reflexivity. } { @@ -1980,7 +1981,7 @@ Section martingale. } Unshelve. match_destr. - now rewrite plus_0_r. + now rewrite Nat.add_0_r. Qed. Lemma hitting_time_from_is_stop @@ -1992,7 +1993,7 @@ Section martingale. unfold hitting_time_from, hitting_time. intros ?. unfold stopping_time_pre_event. - destruct (le_dec old n). + destruct (Compare_dec.le_dec old n). - apply (sa_proper _ (fun x => B (X (n)%nat x) /\ forall k, (old <= k < n)%nat -> ~ B (X k x))). { @@ -2002,7 +2003,7 @@ Section martingale. + destruct HH as [??]. f_equal. apply antisymmetry - ; apply not_lt + ; apply Nat.le_ngt ; intros HH. * apply (classic_min_of_some_first _ _ H (n-old)); [lia |]. now replace (n - old + old)%nat with n by lia. @@ -2026,9 +2027,9 @@ Section martingale. apply sa_inter. + apply adaptX. + apply sa_pre_countable_inter; intros. - destruct (le_dec old n0). + destruct (Compare_dec.le_dec old n0). { - destruct (lt_dec n0 n). + destruct (Compare_dec.lt_dec n0 n). - apply (sa_proper _ (fun x => ~ B (X n0 x))). + intros ?; tauto. + apply sa_complement. diff --git a/coq/ProbTheory/MartingaleConvergence.v b/rocq/ProbTheory/MartingaleConvergence.v similarity index 98% rename from coq/ProbTheory/MartingaleConvergence.v rename to rocq/ProbTheory/MartingaleConvergence.v index 2fdcd2d9..51977e50 100644 --- a/coq/ProbTheory/MartingaleConvergence.v +++ b/rocq/ProbTheory/MartingaleConvergence.v @@ -80,7 +80,7 @@ Section mct. Definition upcrossing_var_expr a b n ts k := match upcrossing_times a b (2*k) ts with | None => 0%nat - | Some upn => if le_dec upn n then k else 0%nat + | Some upn => if Compare_dec.le_dec upn n then k else 0%nat end. Definition upcrossing_var a b n (ts:Ts) : R @@ -190,7 +190,7 @@ Section mct. } * apply sa_countable_union; intros. { - destruct (le_dec n0 m)%nat. + destruct (Compare_dec.le_dec n0 m)%nat. - apply sa_inter. + generalize (upcrossing_times_is_stop a b (2 * n - 1) n0); unfold IsStoppingTime, stopping_time_pre_event. eapply is_filtration_le; trivial. @@ -209,7 +209,7 @@ Section mct. split. - match_destr. intros HH. - destruct (le_dec (S m) n0)%nat; trivial. + destruct (Compare_dec.le_dec (S m) n0)%nat; trivial. elim HH. eexists; split; [reflexivity |]. lia. @@ -223,7 +223,7 @@ Section mct. apply sa_complement. apply sa_countable_union; intros. { - destruct (le_dec n0 m)%nat. + destruct (Compare_dec.le_dec n0 m)%nat. - apply sa_inter. + generalize (upcrossing_times_is_stop a b (2 * n) n0); unfold IsStoppingTime, stopping_time_pre_event. eapply is_filtration_le; trivial. @@ -267,7 +267,7 @@ Section mct. Proof. replace (x + S x)%nat with (S (x + x))%nat by lia. rewrite Nat.even_succ. - rewrite <- NPeano.Nat.negb_even. + rewrite <- Nat.negb_even. now rewrite plus_self_even. Qed. @@ -425,7 +425,7 @@ Section mct. match_case; intros. match_case_in IHh; intros. + rewrite H2 in IHh. - eapply lt_trans. + eapply Nat.lt_trans. apply IHh. apply upcrossing_times_some with (n0 := n1) in H1; trivial; try lia. + apply upcrossing_times_none in H2; try lia. @@ -466,7 +466,7 @@ Section mct. Proof. intros. destruct (upcrossing_times_some_S a b (S k) a0 n1 H0); intros. - destruct (lt_dec 0 k). + destruct (Compare_dec.lt_dec 0 k). - generalize (upcrossing_times_some a b k a0 n0 x l H H1); intros. generalize (upcrossing_times_some a b (S k) a0 x n1); intros. cut_to H3; try lia; trivial. @@ -630,7 +630,7 @@ Section mct. - contrapose. intros. assert (m1 <= m0)%nat by lia. - destruct (lt_dec m1 m0). + destruct (Compare_dec.lt_dec m1 m0). + generalize (upcrossing_times_monotonic_l a b a0 n1 n0 m1 m0 ); intros. cut_to H5; trivial. lia. @@ -710,12 +710,12 @@ Section mct. assert (2 * k + 1 < 2*S x)%nat by lia. apply upcrossing_times_none_plus_alt with (kk := (2 * S x)%nat) in H1; try lia. congruence. - + destruct (lt_dec (2 * S x)%nat (2 * k)%nat). + + destruct (Compare_dec.lt_dec (2 * S x)%nat (2 * k)%nat). * apply upcrossing_times_none_plus_alt with (kk := (2 * k)%nat) in H6; try lia. congruence. - * destruct (lt_dec (2 * k)%nat (2 * S x)%nat). + * destruct (Compare_dec.lt_dec (2 * k)%nat (2 * S x)%nat). -- assert (2 * k + 1 <= 2 * S x - 1)%nat by lia. - destruct (lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat). + destruct (Compare_dec.lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat). ++ apply upcrossing_times_none_plus_alt with (kk := (2 * S x - 1)%nat) in H1; try lia. congruence. ++ assert ( 2 * k + 1 = 2 * S x - 1)%nat by lia. @@ -794,8 +794,8 @@ Section mct. (k2 < k)%nat. Proof. intros. - destruct (le_dec k k2); try lia. - destruct (lt_dec k k2). + destruct (Compare_dec.le_dec k k2); try lia. + destruct (Compare_dec.lt_dec k k2). - now apply (upcrossing_times_none_plus_alt a b k k2 a0) in H; try lia. - assert (k = k2) by lia. now rewrite H1 in H. @@ -829,18 +829,18 @@ Section mct. - match_case_in H; intros; rewrite H5 in H; try easy. match_case_in H; intros; rewrite H6 in H; try easy. destruct H. - destruct (lt_dec (S x) k). + destruct (Compare_dec.lt_dec (S x) k). + assert (n1 < n2)%nat. { specialize (H0 n1 n2 (2 * S x)%nat (2 * k)%nat). cut_to H0; try lia; trivial. } lia. - + destruct (lt_dec k (S x)). + + destruct (Compare_dec.lt_dec k (S x)). * assert (2 * k + 1 <= 2 * S x - 1)%nat by lia. assert (n3 <= n0)%nat. { - destruct (lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat). + destruct (Compare_dec.lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat). - specialize (H0 n3 n0 (2 * k + 1)%nat (2 * S x - 1)%nat). cut_to H0; try lia; trivial. - assert (2 * k + 1 = 2 * S x - 1)%nat by lia. @@ -868,7 +868,7 @@ Section mct. match_case_in H; intros; rewrite H8 in H. + assert (n2 <= n0)%nat. { - destruct (lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat). + destruct (Compare_dec.lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat). - specialize (H0 n2 n0 (2 * k + 1)%nat (2 * S x - 1)%nat). cut_to H0; try lia; trivial. - assert (2 * k + 1 = 2 * S x - 1)%nat by lia. @@ -877,7 +877,7 @@ Section mct. now invcs H3. } lia. - + destruct (lt_dec (2 * k + 1)%nat (2 * S x - 1)). + + destruct (Compare_dec.lt_dec (2 * k + 1)%nat (2 * S x - 1)). * now apply upcrossing_times_none_plus_alt with (kk := (2*S x - 1)%nat) in H8. * assert (2 * k + 1 = 2 * S x - 1)%nat by lia. rewrite H9 in H8. @@ -921,14 +921,14 @@ Section mct. generalize (upcrossing_times_0 a b a0 n1 n2); intros. cut_to H8; trivial; try lia. } - destruct (lt_dec (S x) k). + destruct (Compare_dec.lt_dec (S x) k). + assert (n2 < n0)%nat. { specialize (H5 n2 n0 (2 * S x)%nat (2 * k)%nat). cut_to H5; try lia; trivial. } lia. - + destruct (lt_dec k (S x)). + + destruct (Compare_dec.lt_dec k (S x)). * assert (2 * k + 1 < 2 * S x)%nat by lia. apply (upcrossing_times_none_plus_alt a b (2 * k + 1)%nat (2 * S x)%nat a0) in H2; try lia. congruence. @@ -943,14 +943,14 @@ Section mct. replace (2 * 0)%nat with 0%nat in H7 by lia. congruence. } - destruct (lt_dec (S x) k). + destruct (Compare_dec.lt_dec (S x) k). + apply (upcrossing_times_none_plus_alt a b (2 * S x)%nat (2 * k)%nat a0) in H7; try lia. congruence. - + destruct (lt_dec k (S x)). + + destruct (Compare_dec.lt_dec k (S x)). * assert (2 * k +1 <= 2 * S x - 1)%nat by lia. assert (upcrossing_times a b (2 * S x - 1)%nat a0 = None). { - destruct (lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat). + destruct (Compare_dec.lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat). - now apply (upcrossing_times_none_plus_alt a b (2 * k + 1)%nat (2 * S x - 1)%nat a0) in H2; try lia. - now replace (2 * k + 1)%nat with (2 * S x - 1)%nat in H2 by lia. } @@ -1564,7 +1564,7 @@ Section mct. Definition upcrossing_var_expr1 a b n ts k := match upcrossing_times a b k ts with | None => 0%nat - | Some upn => if le_dec upn n then k else 0%nat + | Some upn => if Compare_dec.le_dec upn n then k else 0%nat end. Lemma upcrossing_bound_transform_ge_0 a b a0 k n0 n : @@ -1598,7 +1598,7 @@ Section mct. match_case_in H6; intros; rewrite H8 in H6; try easy. rewrite H8 in H7. replace (S n - 1)%nat with n in H7 by lia. - destruct (lt_dec n2 (S n)). + destruct (Compare_dec.lt_dec n2 (S n)). + match_case_in H3; intros; rewrite H9 in H3; rewrite H9 in H7; apply H7; try lia. match_destr_in H3; try lia. + assert (S n <= n2)%nat by lia. @@ -1612,7 +1612,7 @@ Section mct. forall j, (upcrossing_var_expr a b (S n) a0 j <= k)%nat. Proof. intros upk j. - destruct (le_dec j (S n)). + destruct (Compare_dec.le_dec j (S n)). - unfold upcrossing_var, upcrossing_var_expr in *. match_option; [| lia]. match_destr; [| lia]. @@ -1669,7 +1669,7 @@ Section mct. (upcrossing_var_expr a b (S n) a0 x1 <= kk)%nat). { intros. - destruct (le_dec x1 (S n)). + destruct (Compare_dec.le_dec x1 (S n)). - apply INR_le. subst. apply Hin'. @@ -1716,7 +1716,7 @@ Section mct. generalize (upcrossing_bound_range0_init a b a0); intros. unfold Hierarchy.sum_n. match_case_in H1; intros; rewrite H3 in H1. - + destruct (lt_dec n n0). + + destruct (Compare_dec.lt_dec n n0). * rewrite (@Hierarchy.sum_n_m_ext_loc Hierarchy.R_AbelianGroup) with (b := fun n1 => Hierarchy.zero). -- rewrite Hierarchy.sum_n_m_const_zero. @@ -1835,7 +1835,7 @@ Section mct. unfold Hierarchy.zero; simpl; lra. - generalize (upcrossing_bound_transform_helper a b a0 n0); intros. match_case_in H1; intros; rewrite H3 in H1. - + destruct (lt_dec (n1-1)%nat n). + + destruct (Compare_dec.lt_dec (n1-1)%nat n). * unfold Hierarchy.sum_n. rewrite Hierarchy.sum_n_m_Chasles with (m := (n1-1)%nat); try lia. unfold Hierarchy.sum_n in H1. @@ -1892,7 +1892,7 @@ Section mct. rewrite Hierarchy.sum_n_m_zero; try lia. unfold Hierarchy.zero; simpl. lra. - - destruct (le_dec 1 (upcrossing_var_expr a b (S n) a0 (S k))). + - destruct (Compare_dec.le_dec 1 (upcrossing_var_expr a b (S n) a0 (S k))). + transitivity (@Hierarchy.sum_n_m Hierarchy.R_AbelianGroup (fun _ => b - a) @@ -2097,7 +2097,7 @@ Section mct. generalize (classic_min_of_some_first _ _ eqq); simpl; unfold id; intros HHY2. generalize (classic_min_of_some_first _ _ eqq0); simpl; unfold id; intros HHM2. apply antisymmetry - ; apply not_lt + ; apply Nat.le_ngt ; intros HH. + apply (HHY2 _ HH). unfold Y, ϕ. @@ -2131,7 +2131,7 @@ Section mct. generalize (classic_min_of_some_first _ _ eqq); simpl; unfold id; intros HHY2. generalize (classic_min_of_some_first _ _ eqq0); simpl; unfold id; intros HHM2. apply antisymmetry - ; apply not_lt + ; apply Nat.le_ngt ; intros HH. + apply (HHY2 _ HH). unfold Y, ϕ. @@ -2797,7 +2797,7 @@ Section mct. destruct (Qs_between_Rbars _ _ r) as [a [b [age [ab blt]]]]. specialize (H0 a b ab). destruct (is_finite_witness _ H0) as [nmax eqq]. - elimtype False. + exfalso. unfold Rbar_rvlim in eqq. generalize (Elim_seq_incr_elem diff --git a/coq/ProbTheory/MartingaleStopped.v b/rocq/ProbTheory/MartingaleStopped.v similarity index 98% rename from coq/ProbTheory/MartingaleStopped.v rename to rocq/ProbTheory/MartingaleStopped.v index af8e7810..869babac 100644 --- a/coq/ProbTheory/MartingaleStopped.v +++ b/rocq/ProbTheory/MartingaleStopped.v @@ -69,10 +69,10 @@ Section stopped_process. unfold lift1_min. rv_unfold; unfold rvsum. destruct n; match_option. - - destruct (Min.min_dec (S n) n0). + - destruct (Nat.min_dec (S n) n0). + assert (nle: (S n <= n0)%nat) by lia. rewrite e. - rewrite (@Hierarchy.sum_n_ext_loc Hierarchy.R_AbelianGroup _ (fun _ => 0)). + rewrite (@Hierarchy.sum_n_ext_loc (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) _ (fun _ => 0)). * rewrite sum_n_zero. field_simplify. match_destr; try lra. @@ -93,7 +93,7 @@ Section stopped_process. rewrite eqq in p. assert (n0 = S n) by lia. subst. - rewrite (@Hierarchy.sum_n_ext_loc Hierarchy.R_AbelianGroup _ (fun _ => 0)). + rewrite (@Hierarchy.sum_n_ext_loc (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) _ (fun _ => 0)). -- rewrite sum_n_zero. lra. -- intros. @@ -116,7 +116,7 @@ Section stopped_process. elim n. now red. - rewrite Hierarchy.sum_Sn. - destruct (le_lt_dec n2 n). + destruct (Compare_dec.le_lt_dec n2 n). + specialize (IHn l). rewrite <- IHn. unfold Hierarchy.plus; simpl. @@ -126,7 +126,7 @@ Section stopped_process. lia. + assert (n2 = S n) by lia. subst. - rewrite (@Hierarchy.sum_n_ext_loc Hierarchy.R_AbelianGroup _ (fun _ => 0)). + rewrite (@Hierarchy.sum_n_ext_loc Hierarchy.R_AbelianMonoid _ (fun _ => 0)). -- rewrite sum_n_zero. unfold Hierarchy.plus; simpl. match_destr; try lra. @@ -137,7 +137,7 @@ Section stopped_process. assert (S n = n0) by congruence. lia. } - - rewrite (@Hierarchy.sum_n_ext Hierarchy.R_AbelianGroup _ (fun _ => 0)). + - rewrite (@Hierarchy.sum_n_ext Hierarchy.R_AbelianMonoid _ (fun _ => 0)). + rewrite sum_n_zero. field_simplify. match_destr; try lra. @@ -689,7 +689,7 @@ Section stopped_process. unfold Hierarchy.plus; simpl. destruct (Nat.eq_dec n (S n0)). * subst. - assert (0 <= @Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n0 : nat => Rabs (Y n0 x)) n0). + assert (0 <= @Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n0 : nat => Rabs (Y n0 x)) n0). { apply sum_n_nneg; intros. apply Rabs_pos. @@ -1078,7 +1078,7 @@ Section stopped_process. Qed. Lemma Rabs_sum_n_triang f n : - Rabs (@Hierarchy.sum_n Hierarchy.R_AbelianGroup f n) <= @Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun k => Rabs (f k)) n. + Rabs (@Hierarchy.sum_n Hierarchy.R_AbelianMonoid f n) <= @Hierarchy.sum_n Hierarchy.R_AbelianMonoid (fun k => Rabs (f k)) n. Proof. induction n. - now repeat rewrite Hierarchy.sum_O. @@ -1181,7 +1181,7 @@ Section stopped_process. unfold rvsum. simpl. rewrite Rabs_sum_n_triang. - transitivity (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun _ => K) n0). + transitivity (@Hierarchy.sum_n Hierarchy.R_AbelianMonoid (fun _ => K) n0). { apply sum_n_le_loc; intros. replace (Y (S n1) x + -1 * Y n1 x) diff --git a/coq/ProbTheory/Measures.v b/rocq/ProbTheory/Measures.v similarity index 99% rename from coq/ProbTheory/Measures.v rename to rocq/ProbTheory/Measures.v index 508d6f83..633777d8 100644 --- a/coq/ProbTheory/Measures.v +++ b/rocq/ProbTheory/Measures.v @@ -31,7 +31,7 @@ Section measure. Class is_measure (μ:event σ -> Rbar) := mk_measure { - measure_proper :> Proper (event_equiv ==> eq) μ + measure_proper ::> Proper (event_equiv ==> eq) μ ; measure_none : μ event_none = 0%R ; measure_nneg a : Rbar_le 0 (μ a) ; measure_countable_disjoint_union (B:nat->event σ) : @@ -283,7 +283,7 @@ Section outer_measure. Class is_outer_measure (μ:pre_event T -> Rbar) := mk_outer_measure { - outer_measure_proper :> Proper (pre_event_equiv ==> eq) μ + outer_measure_proper ::> Proper (pre_event_equiv ==> eq) μ ; outer_measure_none : μ pre_event_none = 0%R ; outer_measure_nneg a : Rbar_le 0 (μ a) ; outer_measure_countable_subadditive (A:pre_event T) (B:nat->pre_event T) : @@ -295,7 +295,7 @@ Section outer_measure. := mk_outer_measure_alt { outer_measure_alt_none : μ pre_event_none = 0%R ; outer_measure_alt_nneg a : Rbar_le 0 (μ a) - ; outer_measure_alt_monotone :> Proper (pre_event_sub ==> Rbar_le) μ + ; outer_measure_alt_monotone ::> Proper (pre_event_sub ==> Rbar_le) μ ; outer_measure_alt_countable_union (B:nat->pre_event T) : Rbar_le (μ (pre_union_of_collection B)) (ELim_seq (fun i : nat => sum_Rbar_n (fun n : nat => μ (B n)) i)) }. @@ -1597,7 +1597,7 @@ Section premeasure. (* we could generalize events, but that is too much work for now :-) *) Class is_premeasure (λ:alg_set Alg -> Rbar) := mk_premeasure { - premeasure_proper :> Proper (alg_equiv ==> eq) λ + premeasure_proper ::> Proper (alg_equiv ==> eq) λ ; premeasure_none : λ alg_none = 0%R ; premeasure_nneg a : Rbar_le 0 (λ a) ; premeasure_countable_disjoint_union (B:nat->alg_set Alg) : @@ -1714,7 +1714,7 @@ Section premeasure. apply alg_make_collection_disjoint_in in e2. destruct e1 as [H11 H12]. destruct e2 as [H21 H22]. - destruct (not_eq _ _ xyneq) as [xlt|ylt]. + destruct (Compare_dec.not_eq _ _ xyneq) as [xlt|ylt]. - eapply H22; eauto. - eapply H12; eauto. Qed. @@ -2497,7 +2497,7 @@ Section semi_premeasure. Class is_semipremeasure (λ:salg_set SAlg -> Rbar) := mk_semipremeasure { - semipremeasure_proper :> Proper (salg_equiv ==> eq) λ + semipremeasure_proper ::> Proper (salg_equiv ==> eq) λ ; semipremeasure_nneg a : Rbar_le 0 (λ a) ; semipremeasure_list_disjoint_union (B:list (salg_set SAlg)) : diff --git a/coq/ProbTheory/OrthoProject.v b/rocq/ProbTheory/OrthoProject.v similarity index 99% rename from coq/ProbTheory/OrthoProject.v rename to rocq/ProbTheory/OrthoProject.v index 675da36d..907f2bf7 100644 --- a/coq/ProbTheory/OrthoProject.v +++ b/rocq/ProbTheory/OrthoProject.v @@ -371,9 +371,9 @@ Section ortho_project. exists (max NP NQ); intros n nN. split. * apply HP. - generalize (Max.le_max_l NP NQ); lia. + generalize (Nat.le_max_l NP NQ); lia. * apply HQ. - generalize (Max.le_max_r NP NQ); lia. + generalize (Nat.le_max_r NP NQ); lia. + intros P Q Himp [N HP]. exists N; intros n nN. auto. diff --git a/coq/ProbTheory/ProbSpace.v b/rocq/ProbTheory/ProbSpace.v similarity index 99% rename from coq/ProbTheory/ProbSpace.v rename to rocq/ProbTheory/ProbSpace.v index 28f830d5..c9aac906 100644 --- a/coq/ProbTheory/ProbSpace.v +++ b/rocq/ProbTheory/ProbSpace.v @@ -32,7 +32,7 @@ Definition sum_of_probs_equals {T:Type} {σ:SigmaAlgebra T} Class ProbSpace {T:Type} (σ:SigmaAlgebra T) := { ps_P : event σ -> R; - ps_proper :> Proper (event_equiv ==> eq) ps_P ; + ps_proper ::> Proper (event_equiv ==> eq) ps_P ; ps_countable_disjoint_union (collection: nat -> event σ) : (* Assume: collection is a subset of Sigma and its elements are pairwise disjoint. *) @@ -292,7 +292,7 @@ Qed. Definition make_collection_disjoint {T:Type} {σ:SigmaAlgebra T} (coll:nat->event σ) : nat -> event σ := fun x => coll x \ (union_of_collection (fun y => - if lt_dec y x + if Compare_dec.lt_dec y x then coll y else ∅)). @@ -328,13 +328,13 @@ Proof. intros y ylt cy. apply H2. exists y. - destruct (lt_dec y x); intuition. + destruct (Compare_dec.lt_dec y x); intuition. - intros [ce fce]. unfold make_collection_disjoint. split; trivial. unfold union_of_collection. intros [n Hn]. - destruct (lt_dec n x); trivial. + destruct (Compare_dec.lt_dec n x); trivial. eapply fce; eauto. Qed. @@ -346,7 +346,7 @@ Proof. apply make_collection_disjoint_in in e2. destruct e1 as [H11 H12]. destruct e2 as [H21 H22]. - destruct (not_eq _ _ xyneq) as [xlt|ylt]. + destruct (Compare_dec.not_eq _ _ xyneq) as [xlt|ylt]. - eapply H22; eauto. - eapply H12; eauto. Qed. @@ -395,7 +395,7 @@ Section classic. split; trivial. unfold union_of_collection. intros [nn Hnn]. - destruct (lt_dec nn m); [ | tauto]. + destruct (Compare_dec.lt_dec nn m); [ | tauto]. specialize (H0 _ Hnn). lia. - apply make_collection_disjoint_in in Hn. @@ -478,7 +478,7 @@ Section ascending. replace m with (0%nat) by lia. reflexivity. - intros. - apply le_lt_or_eq in H. + apply Nat.lt_eq_cases in H. destruct H. + red in asc. rewrite <- asc. @@ -522,7 +522,7 @@ Section ascending. * now apply make_collection_disjoint_sub. + red. unfold make_collection_disjoint. - destruct (classic (proj1_sig (union_of_collection (fun y : nat => if lt_dec y (S n) then En y else event_none)) a)). + destruct (classic (proj1_sig (union_of_collection (fun y : nat => if Compare_dec.lt_dec y (S n) then En y else event_none)) a)). * destruct H as [x HH2]. match_destr_in HH2; [ | red in HH2; tauto]. left. diff --git a/coq/ProbTheory/ProductSpace.v b/rocq/ProbTheory/ProductSpace.v similarity index 99% rename from coq/ProbTheory/ProductSpace.v rename to rocq/ProbTheory/ProductSpace.v index f89152c9..55aa6516 100644 --- a/coq/ProbTheory/ProductSpace.v +++ b/rocq/ProbTheory/ProductSpace.v @@ -2983,8 +2983,8 @@ Qed. - apply fst_rv. - generalize compose_rv; intros HH. cut ( - RandomVariable (product_sa s (ivector_sa i)) (ivector_nth idx (lt_S_n idx n idx_lt) i) - (ivector_nth idx (lt_S_n idx n idx_lt) ∘ snd)). + RandomVariable (product_sa s (ivector_sa i)) (ivector_nth idx (proj2 (Nat.succ_lt_mono idx n) idx_lt) i) + (ivector_nth idx (proj2 (Nat.succ_lt_mono idx n) idx_lt) ∘ snd)). { apply RandomVariable_proper; try reflexivity. now intros [??]. @@ -3011,12 +3011,12 @@ Qed. Proof. intros. destruct n; try lia. - assert (RandomVariable (ivector_sa (ivector_const (S n) σ)) σ (fun x : ivector T (S n) => ivector_nth idx2 (lt_S_n idx2 n pf2) (ivector_tl x))). + assert (RandomVariable (ivector_sa (ivector_const (S n) σ)) σ (fun x : ivector T (S n) => ivector_nth idx2 (proj2 (Nat.succ_lt_mono idx2 n) pf2) (ivector_tl x))). { generalize (compose_rv (dom1 := (ivector_sa (ivector_const (S n) σ))) (dom2 := (ivector_sa (ivector_const n σ))) ivector_tl - (fun x => ivector_nth idx2 (lt_S_n idx2 n pf2) x)); intros. + (fun x => ivector_nth idx2 (proj2 (Nat.succ_lt_mono idx2 n) pf2) x)); intros. apply H; typeclasses eauto. } assert (independent_rvs (ivector_ps ivec_ps) σ σ ivector_hd @@ -3025,7 +3025,7 @@ Qed. generalize (independent_rv_compose (ivector_ps ivec_ps) σ (ivector_sa (ivector_const n σ)) σ σ ivector_hd ivector_tl - (fun x => x) (fun x => ivector_nth idx2 (lt_S_n idx2 n pf2) x) + (fun x => x) (fun x => ivector_nth idx2 (proj2 (Nat.succ_lt_mono idx2 n) pf2) x) ); intros. cut_to H0. - revert H0. @@ -3059,14 +3059,14 @@ Qed. destruct idx2; [lia |]. destruct idx1. + apply (ivector_nth_independent_rvs_0 (n:=S n) (p,i) idx2). - + generalize (IHn i idx1 idx2 (lt_S_n idx1 n pf1) (lt_S_n idx2 n pf2) (lt_S_n idx1 idx2 H)). + + generalize (IHn i idx1 idx2 (proj2 (Nat.succ_lt_mono idx1 n) pf1) (proj2 (Nat.succ_lt_mono idx2 n) pf2) (proj2 (Nat.succ_lt_mono idx1 idx2) H)). unfold independent_rvs, independent_events; intros HH A B. specialize (HH A B). etransitivity; [| etransitivity; [apply HH |]]. * generalize (product_sa_product p (ivector_ps i) Ω - (rv_preimage (fun tl => ivector_nth idx1 (lt_S_n idx1 n pf1) tl) A - ∩ rv_preimage (fun tl => ivector_nth idx2 (lt_S_n idx2 n pf2) tl) B)); intros HH2. + (rv_preimage (fun tl => ivector_nth idx1 (proj2 (Nat.succ_lt_mono idx1 n) pf1) tl) A + ∩ rv_preimage (fun tl => ivector_nth idx2 (proj2 (Nat.succ_lt_mono idx2 n) pf2) tl) B)); intros HH2. rewrite ps_one, Rmult_1_l in HH2. rewrite <- HH2. apply ps_proper; intros [??]; simpl. @@ -3074,14 +3074,14 @@ Qed. * { f_equal. - generalize (product_sa_product p (ivector_ps i) Ω - (rv_preimage (fun tl => ivector_nth idx1 (lt_S_n idx1 n pf1) tl) A)); intros HH2. + (rv_preimage (fun tl => ivector_nth idx1 (proj2 (Nat.succ_lt_mono idx1 n) pf1) tl) A)); intros HH2. rewrite ps_one, Rmult_1_l in HH2. rewrite <- HH2. apply ps_proper; intros [??]; simpl. unfold pre_Ω, event_preimage, pre_event_inter; tauto. - generalize (product_sa_product p (ivector_ps i) Ω - (rv_preimage (fun tl => ivector_nth idx2 (lt_S_n idx2 n pf2) tl) B)); intros HH2. + (rv_preimage (fun tl => ivector_nth idx2 (proj2 (Nat.succ_lt_mono idx2 n) pf2) tl) B)); intros HH2. rewrite ps_one, Rmult_1_l in HH2. rewrite <- HH2. apply ps_proper; intros [??]; simpl. @@ -4419,7 +4419,7 @@ Section ps_sequence_product. simpl. unfold inf_cylinder_event, section_seq_event in e0. unfold inf_cylinder_event in H3. - pose (w := fun i => match lt_dec i N with + pose (w := fun i => match Compare_dec.lt_dec i N with | left pf => ivector_nth i pf x4 | right _ => inh end). diff --git a/coq/ProbTheory/ProductSpaceDep.v b/rocq/ProbTheory/ProductSpaceDep.v similarity index 100% rename from coq/ProbTheory/ProductSpaceDep.v rename to rocq/ProbTheory/ProductSpaceDep.v diff --git a/coq/ProbTheory/RandomVariable.v b/rocq/ProbTheory/RandomVariable.v similarity index 100% rename from coq/ProbTheory/RandomVariable.v rename to rocq/ProbTheory/RandomVariable.v diff --git a/coq/ProbTheory/RandomVariableFinite.v b/rocq/ProbTheory/RandomVariableFinite.v similarity index 99% rename from coq/ProbTheory/RandomVariableFinite.v rename to rocq/ProbTheory/RandomVariableFinite.v index fc87871e..db66c975 100644 --- a/coq/ProbTheory/RandomVariableFinite.v +++ b/rocq/ProbTheory/RandomVariableFinite.v @@ -688,11 +688,11 @@ Section fe. (Xn : nat -> Ts -> R) (Xn_pos : forall n, NonnegativeFunction (Xn n)) (is_fin_lim : - forall omega, is_finite (Lim_seq (sum_n (fun n => Xn n omega)))): - NonnegativeFunction (fun omega => Lim_seq (sum_n (fun n => Xn n omega))). + forall omega, is_finite (Lim_seq (@sum_n R_AbelianGroup (fun n => Xn n omega)))): + NonnegativeFunction (fun omega => Lim_seq (@sum_n R_AbelianGroup (fun n => Xn n omega))). Proof. unfold NonnegativeFunction in *; intros. - generalize (Lim_seq_pos (sum_n (fun n : nat => Xn n x))). + generalize (Lim_seq_pos (@sum_n R_AbelianGroup (fun n : nat => Xn n x))). rewrite <- is_fin_lim; simpl. intros; apply H. intros. @@ -726,7 +726,7 @@ Section fe. (Xn_rv : forall n, RandomVariable dom borel_sa (Xn n)) (isfe : forall n, IsFiniteExpectation (Xn n)) : forall (n:nat), - sum_n (fun n0 : nat => FiniteExpectation (Xn n0)) n = + @sum_n R_AbelianGroup (fun n0 : nat => FiniteExpectation (Xn n0)) n = FiniteExpectation (rvsum Xn n). Proof. intros. @@ -883,8 +883,8 @@ Lemma Fatou_FiniteExpectation apply H. intros. unfold rvsum, sum_n. - replace (sum_n_m (fun n1 : nat => Xn n1 x) 0 n0) with - (sum_n_m (fun n1 : nat => Xn n1 x) 0 n0 + 0) by lra. + replace (@sum_n_m R_AbelianGroup (fun n1 : nat => Xn n1 x) 0 n0) with + (@sum_n_m R_AbelianGroup (fun n1 : nat => Xn n1 x) 0 n0 + 0) by lra. rewrite sum_n_Sm; [|lia]. unfold plus; simpl. apply Rplus_le_compat_l. diff --git a/coq/ProbTheory/RandomVariableL2.v b/rocq/ProbTheory/RandomVariableL2.v similarity index 100% rename from coq/ProbTheory/RandomVariableL2.v rename to rocq/ProbTheory/RandomVariableL2.v diff --git a/coq/ProbTheory/RandomVariableLinf.v b/rocq/ProbTheory/RandomVariableLinf.v similarity index 99% rename from coq/ProbTheory/RandomVariableLinf.v rename to rocq/ProbTheory/RandomVariableLinf.v index 8815c594..7b3ec206 100644 --- a/coq/ProbTheory/RandomVariableLinf.v +++ b/rocq/ProbTheory/RandomVariableLinf.v @@ -1537,14 +1537,20 @@ Qed. LiRRVq_simpl. apply LiRRV_plus_inv. Qed. - - Definition LiRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of LiRRVq - := AbelianGroup.Mixin LiRRVq LiRRVq_plus LiRRVq_opp LiRRVq_zero + + Definition LiRRVq_AbelianMonoid_mixin : AbelianMonoid.mixin_of LiRRVq + := AbelianMonoid.Mixin LiRRVq LiRRVq_plus LiRRVq_zero LiRRVq_plus_comm LiRRVq_plus_assoc - LiRRVq_plus_zero LiRRVq_plus_inv. + LiRRVq_plus_zero. + + Canonical LiRRVq_AbelianMonoid := + AbelianMonoid.Pack LiRRVq LiRRVq_AbelianMonoid_mixin LiRRVq. + + Definition LiRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of LiRRVq_AbelianMonoid + := AbelianGroup.Mixin LiRRVq_AbelianMonoid LiRRVq_opp LiRRVq_plus_inv. Canonical LiRRVq_AbelianGroup := - AbelianGroup.Pack LiRRVq LiRRVq_AbelianGroup_mixin LiRRVq. + AbelianGroup.Pack LiRRVq (AbelianGroup.Class _ (LiRRVq_AbelianMonoid_mixin) LiRRVq_AbelianGroup_mixin) LiRRVq. Ltac LiRRVq_simpl ::= repeat match goal with @@ -1590,7 +1596,7 @@ Qed. LiRRVq_scale_plus_l LiRRVq_scale_plus_r. Canonical LiRRVq_ModuleSpace := - ModuleSpace.Pack R_Ring LiRRVq (ModuleSpace.Class R_Ring LiRRVq LiRRVq_AbelianGroup_mixin LiRRVq_ModuleSpace_mixin) LiRRVq. + ModuleSpace.Pack R_Ring LiRRVq (ModuleSpace.Class R_Ring LiRRVq _ LiRRVq_ModuleSpace_mixin) LiRRVq. Definition LiRRVq_norm : LiRRVq -> R := quot_rec LiRRV_norm_proper. diff --git a/coq/ProbTheory/RandomVariableLpNat.v b/rocq/ProbTheory/RandomVariableLpNat.v similarity index 98% rename from coq/ProbTheory/RandomVariableLpNat.v rename to rocq/ProbTheory/RandomVariableLpNat.v index ccf6d690..dd97852e 100644 --- a/coq/ProbTheory/RandomVariableLpNat.v +++ b/rocq/ProbTheory/RandomVariableLpNat.v @@ -343,9 +343,9 @@ Section Lp. * unfold Binomial.C. left; unfold Rdiv. apply Rmult_lt_0_compat. - -- apply lt_0_INR; apply lt_O_fact. + -- apply lt_0_INR; apply Factorial.lt_O_fact. -- apply Rinv_0_lt_compat. - apply Rmult_lt_0_compat; apply lt_0_INR; apply lt_O_fact. + apply Rmult_lt_0_compat; apply lt_0_INR; apply Factorial.lt_O_fact. * destruct (Rle_dec y x). -- replace (p) with (i + (p-i))%nat at 2. ++ rewrite Rdef_pow_add. @@ -573,8 +573,8 @@ Section Lp. Record LpRRV : Type := LpRRV_of { LpRRV_rv_X :> Ts -> R - ; LpRRV_rv :> RandomVariable dom borel_sa LpRRV_rv_X - ; LpRRV_lp :> IsLp p LpRRV_rv_X + ; LpRRV_rv ::> RandomVariable dom borel_sa LpRRV_rv_X + ; LpRRV_lp ::> IsLp p LpRRV_rv_X }. Global Existing Instance LpRRV_rv. @@ -869,13 +869,19 @@ Section Lp. apply LpRRV_plus_inv. Qed. - Definition LpRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of LpRRVq - := AbelianGroup.Mixin LpRRVq LpRRVq_plus LpRRVq_opp LpRRVq_zero + Definition LpRRVq_AbelianMonoid_mixin : AbelianMonoid.mixin_of LpRRVq + := AbelianMonoid.Mixin LpRRVq LpRRVq_plus LpRRVq_zero LpRRVq_plus_comm LpRRVq_plus_assoc - LpRRVq_plus_zero LpRRVq_plus_inv. + LpRRVq_plus_zero. + + Canonical LpRRVq_AbelianMonoid := + AbelianMonoid.Pack LpRRVq LpRRVq_AbelianMonoid_mixin LpRRVq. + + Definition LpRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of LpRRVq_AbelianMonoid + := AbelianGroup.Mixin LpRRVq_AbelianMonoid LpRRVq_opp LpRRVq_plus_inv. Canonical LpRRVq_AbelianGroup := - AbelianGroup.Pack LpRRVq LpRRVq_AbelianGroup_mixin LpRRVq. + AbelianGroup.Pack LpRRVq (AbelianGroup.Class _ (LpRRVq_AbelianMonoid_mixin) LpRRVq_AbelianGroup_mixin) LpRRVq. Ltac LpRRVq_simpl ::= @@ -922,7 +928,7 @@ Section Lp. LpRRVq_scale_plus_l LpRRVq_scale_plus_r. Canonical LpRRVq_ModuleSpace := - ModuleSpace.Pack R_Ring LpRRVq (ModuleSpace.Class R_Ring LpRRVq LpRRVq_AbelianGroup_mixin LpRRVq_ModuleSpace_mixin) LpRRVq. + ModuleSpace.Pack R_Ring LpRRVq (ModuleSpace.Class R_Ring LpRRVq _ LpRRVq_ModuleSpace_mixin) LpRRVq. End quot. @@ -1014,7 +1020,7 @@ Section Lp. Proof. repeat rewrite Rsqr_pow2. repeat rewrite <- pow_mult. - now rewrite mult_comm. + now rewrite Nat.mul_comm. Qed. Lemma pow_incr_inv (x y:R) (n : nat) : diff --git a/coq/ProbTheory/RandomVariableLpR.v b/rocq/ProbTheory/RandomVariableLpR.v similarity index 99% rename from coq/ProbTheory/RandomVariableLpR.v rename to rocq/ProbTheory/RandomVariableLpR.v index 7824b88b..d30038d5 100644 --- a/coq/ProbTheory/RandomVariableLpR.v +++ b/rocq/ProbTheory/RandomVariableLpR.v @@ -469,8 +469,8 @@ Qed. Record LpRRV : Type := LpRRV_of { LpRRV_rv_X :> Ts -> R - ; LpRRV_rv :> RandomVariable dom borel_sa LpRRV_rv_X - ; LpRRV_lp :> IsLp p LpRRV_rv_X + ; LpRRV_rv ::> RandomVariable dom borel_sa LpRRV_rv_X + ; LpRRV_lp ::> IsLp p LpRRV_rv_X }. Global Existing Instance LpRRV_rv. @@ -886,13 +886,19 @@ Qed. apply LpRRV_plus_inv. Qed. - Definition LpRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of (LpRRVq p) - := AbelianGroup.Mixin (LpRRVq p) LpRRVq_plus LpRRVq_opp LpRRVq_zero + Definition LpRRVq_AbelianMonoid_mixin : AbelianMonoid.mixin_of (LpRRVq p) + := AbelianMonoid.Mixin (LpRRVq p) LpRRVq_plus LpRRVq_zero LpRRVq_plus_comm LpRRVq_plus_assoc - LpRRVq_plus_zero LpRRVq_plus_inv. + LpRRVq_plus_zero. + + Canonical LpRRVq_AbelianMonoid := + AbelianMonoid.Pack (LpRRVq p) LpRRVq_AbelianMonoid_mixin (LpRRVq p). + + Definition LpRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of LpRRVq_AbelianMonoid + := AbelianGroup.Mixin LpRRVq_AbelianMonoid LpRRVq_opp LpRRVq_plus_inv. Canonical LpRRVq_AbelianGroup := - AbelianGroup.Pack (LpRRVq p) LpRRVq_AbelianGroup_mixin (LpRRVq p). + AbelianGroup.Pack (LpRRVq p) (AbelianGroup.Class _ (LpRRVq_AbelianMonoid_mixin) LpRRVq_AbelianGroup_mixin) (LpRRVq p). Ltac LpRRVq_simpl ::= repeat match goal with @@ -938,7 +944,7 @@ Qed. LpRRVq_scale_plus_l LpRRVq_scale_plus_r. Canonical LpRRVq_ModuleSpace := - ModuleSpace.Pack R_Ring (LpRRVq p) (ModuleSpace.Class R_Ring (LpRRVq p) LpRRVq_AbelianGroup_mixin LpRRVq_ModuleSpace_mixin) (LpRRVq p). + ModuleSpace.Pack R_Ring (LpRRVq p) (ModuleSpace.Class R_Ring (LpRRVq p) _ LpRRVq_ModuleSpace_mixin) (LpRRVq p). End quotnneg. End packednonneg. @@ -1409,7 +1415,7 @@ Qed. sum_n_m (fun k => pow c k) (S n) m = (pow c (S m) - pow c (S n))/(c-1). Proof. intros. - rewrite sum_n_m_sum_n; [|lia]. + rewrite (@sum_n_m_sum_n R_AbelianGroup); [|lia]. rewrite sum_geom; trivial. rewrite sum_geom; trivial. unfold minus, plus, opp; simpl. diff --git a/coq/ProbTheory/RbarExpectation.v b/rocq/ProbTheory/RbarExpectation.v similarity index 99% rename from coq/ProbTheory/RbarExpectation.v rename to rocq/ProbTheory/RbarExpectation.v index 43ad83df..4c02d671 100644 --- a/coq/ProbTheory/RbarExpectation.v +++ b/rocq/ProbTheory/RbarExpectation.v @@ -3673,8 +3673,8 @@ Theorem Dominated_convergence erewrite (Rbar_NonnegExpectation_ext _ _ H7) in le2. erewrite (Rbar_NonnegExpectation_ext _ _ H8) in le3. - rewrite (Rbar_FiniteExpectation_Rbar_NonnegExpectation _) in le2. - rewrite (Rbar_FiniteExpectation_Rbar_NonnegExpectation _) in le3. + erewrite (Rbar_FiniteExpectation_Rbar_NonnegExpectation _) in le2. + erewrite (Rbar_FiniteExpectation_Rbar_NonnegExpectation _) in le3. rewrite (ELimInf_proper _ (fun n => (Rbar_FiniteExpectation g) + (Rbar_FiniteExpectation (fn n)))) in le2. @@ -3772,6 +3772,10 @@ Theorem Dominated_convergence -- apply is_Elim_seq_const. -- apply lim1. -- destruct (f x); reflexivity. + + eauto. + + eauto. + + eauto. + + eauto. Qed. Theorem Dominated_convergence_almost @@ -4543,7 +4547,7 @@ Section rv_expressible. rewrite <- H4. apply is_lim_seq_spec in H1. destruct (H1 M). - destruct (le_dec x1 n). + destruct (Compare_dec.le_dec x1 n). * now apply H5. * assert (n <= x1)%nat by lia. apply Rle_lt_trans with (r2 := simple_approx Y x1 x). diff --git a/coq/ProbTheory/RealRandomVariable.v b/rocq/ProbTheory/RealRandomVariable.v similarity index 100% rename from coq/ProbTheory/RealRandomVariable.v rename to rocq/ProbTheory/RealRandomVariable.v diff --git a/coq/ProbTheory/RealVectorHilbert.v b/rocq/ProbTheory/RealVectorHilbert.v similarity index 98% rename from coq/ProbTheory/RealVectorHilbert.v rename to rocq/ProbTheory/RealVectorHilbert.v index 9891d148..63c23d53 100644 --- a/coq/ProbTheory/RealVectorHilbert.v +++ b/rocq/ProbTheory/RealVectorHilbert.v @@ -333,13 +333,19 @@ Section Rvector_defs. apply Rvector_inv_plus. Qed. - Definition Rvector_AbelianGroup_mixin : AbelianGroup.mixin_of (vector R n) - := AbelianGroup.Mixin (vector R n) Rvector_plus Rvector_opp Rvector_zero + Definition Rvector_AbelianMonoid_mixin : AbelianMonoid.mixin_of (vector R n) + := AbelianMonoid.Mixin (vector R n) Rvector_plus Rvector_zero Rvector_plus_comm Rvector_plus_assoc - Rvector_plus_zero Rvector_plus_inv. + Rvector_plus_zero. + + Canonical Rvector_AbelianMonoid := + AbelianMonoid.Pack (vector R n) Rvector_AbelianMonoid_mixin (vector R n). + + Definition Rvector_AbelianGroup_mixin : AbelianGroup.mixin_of (Rvector_AbelianMonoid) + := AbelianGroup.Mixin Rvector_AbelianMonoid Rvector_opp Rvector_plus_inv. Canonical Rvector_AbelianGroup := - AbelianGroup.Pack (vector R n) Rvector_AbelianGroup_mixin (vector R n). + AbelianGroup.Pack (vector R n) (AbelianGroup.Class _ (Rvector_AbelianMonoid_mixin) Rvector_AbelianGroup_mixin) (vector R n). Lemma Rvector_scale_scale (a b:R) (v:vector R n) : a .* (b .* v) = (a * b) .* v. @@ -390,7 +396,7 @@ Section Rvector_defs. Rvector_scale_plus_l Rvector_scale_plus_r. Canonical Rvector_ModuleSpace := - ModuleSpace.Pack R_Ring (vector R n) (ModuleSpace.Class R_Ring (vector R n) Rvector_AbelianGroup_mixin Rvector_ModuleSpace_mixin) (vector R n). + ModuleSpace.Pack R_Ring (vector R n) (ModuleSpace.Class R_Ring (vector R n) _ Rvector_ModuleSpace_mixin) (vector R n). Lemma Rvector_scale_inj (c:R) (x y:vector R n) : c <> 0%R -> c .* x = c .* y -> x = y. @@ -950,7 +956,7 @@ Section Rvector_defs. forall m (pf2:(m <= n)%nat), F (fun v : vector R n => forall (i : nat) (pf : (i < m)%nat), - P i (lt_le_trans _ _ _ pf pf2) (vector_nth i (lt_le_trans _ _ _ pf pf2) v)). + P i (Nat.lt_le_trans _ _ _ pf pf2) (vector_nth i (Nat.lt_le_trans _ _ _ pf pf2) v)). Proof. intros [???] FA. induction m; simpl; intros mle. @@ -964,8 +970,8 @@ Section Rvector_defs. forall (i : nat) (pf : (i < m)%nat), P i (Nat.lt_le_trans i (S m) n (pft _ pf) mle) (vector_nth i (Nat.lt_le_trans i (S m) n (pft _ pf) mle) v))). + apply filter_imp; intros. - generalize (lt_n_Sm_le _ _ pf); intros pf2. - apply le_lt_or_eq in pf2. + generalize (proj1 (Nat.lt_succ_r _ _) pf); intros pf2. + apply Nat.lt_eq_cases in pf2. destruct H. destruct pf2. * generalize (H0 _ H1). @@ -992,7 +998,7 @@ Section Rvector_defs. P i pf (vector_nth i pf v)). Proof. intros. - generalize (Filter_Forall_commute_aux _ _ H H0 n (le_refl n)). + generalize (Filter_Forall_commute_aux _ _ H H0 n (Nat.le_refl n)). destruct H. apply filter_imp; intros. specialize (H i pf). @@ -1064,7 +1070,7 @@ Section Rvector_defs. destruct H as [?[???]]. generalize filter_true. apply filter_imp; intros. - rewrite (vector_zero0 e (Rvector_lim (fun x : vector R n -> Prop => F x))). + rewrite (vector_zero0 e (Rvector_lim F)). rewrite (vector_zero0 e x). apply ball_center. - apply Rvector_lim_complete_pos. diff --git a/coq/ProbTheory/SigmaAlgebras.v b/rocq/ProbTheory/SigmaAlgebras.v similarity index 98% rename from coq/ProbTheory/SigmaAlgebras.v rename to rocq/ProbTheory/SigmaAlgebras.v index d1fe763e..75c87ab1 100644 --- a/coq/ProbTheory/SigmaAlgebras.v +++ b/rocq/ProbTheory/SigmaAlgebras.v @@ -7,7 +7,7 @@ Require Import Morphisms EquivDec Program. Require Import Utils DVector. Require Export Event. -Require Import Lia. +Require Import Arith Lia. Set Bullet Behavior "Strict Subproofs". @@ -1635,13 +1635,6 @@ Qed. reflexivity. Qed. - Lemma lt_S_n_S i1 n pf : - (Lt.lt_S_n i1 n (Lt.lt_n_S i1 n pf)) = pf. - Proof. - apply digit_pf_irrel. - Qed. - - Lemma generated_rectangle_proj {T} {n} (s : SigmaAlgebra T) (i : ivector (SigmaAlgebra T) n) (e : pre_event (ivector T n)) : sa_sigma (generated_sa (pre_event_set_ivector_product (ivector_map sa_sigma i))) e -> sa_sigma (generated_sa (pre_event_set_ivector_product (ivector_map sa_sigma (n:=S n) (s, i)))) (fun '(_, x₂) => e x₂). @@ -1668,8 +1661,11 @@ Qed. apply HH. -- apply H0; intros. destruct x0. - specialize (HH (S i0) (Lt.lt_n_S _ _ pf)); simpl in *. - now rewrite lt_S_n_S in HH. + assert (pf':S i0 < S n) by lia. + specialize (HH (S i0) pf'); simpl in *. + erewrite (ivector_nth_prf_irrelevance _ x). + erewrite (ivector_nth_prf_irrelevance _ i1). + apply HH. Qed. Lemma ivector_rectangles_generate_sa {n} {T} @@ -1761,7 +1757,7 @@ Qed. destruct x0. exists p. exists (fun v => ivector_Forall2 (fun a0 (x : T) => a0 x) i0 v). - generalize (H1 0 (NPeano.Nat.lt_0_succ n)); intros. + generalize (H1 0 (Nat.lt_0_succ n)); intros. simpl in H3. split; trivial. split. @@ -1770,23 +1766,27 @@ Qed. unfold pre_event_set_ivector_product. exists i0; split. -- intros. - specialize (H1 (S i1) (Lt.lt_n_S i1 n pf)). + specialize (H1 (S i1) (proj1 (Nat.succ_lt_mono i1 n) pf)). simpl in H1. - now rewrite lt_S_n_S in H1. + erewrite (ivector_nth_prf_irrelevance _ (ivector_map sa_sigma i)). + erewrite (ivector_nth_prf_irrelevance _ i0). + apply H1. -- intros ?. now rewrite <- ivector_Forall2_nth_iff. * rewrite H2. intros ?. destruct x0. split; intros. - generalize (H4 0 (NPeano.Nat.lt_0_succ n)); intros. + generalize (H4 0 (Nat.lt_0_succ n)); intros. simpl in H5. split; trivial. -- rewrite <- ivector_Forall2_nth_iff. intros. - specialize (H4 (S i2) (Lt.lt_n_S i2 n pf)). + specialize (H4 (S i2) ((proj1 (Nat.succ_lt_mono i2 n) pf))). simpl in H4. - now rewrite lt_S_n_S in H4. + erewrite (ivector_nth_prf_irrelevance _ i0). + erewrite (ivector_nth_prf_irrelevance _ i1). + apply H4. -- destruct H4. rewrite <- ivector_Forall2_nth_iff in H5. destruct i2. diff --git a/coq/ProbTheory/SimpleExpectation.v b/rocq/ProbTheory/SimpleExpectation.v similarity index 99% rename from coq/ProbTheory/SimpleExpectation.v rename to rocq/ProbTheory/SimpleExpectation.v index 635c8eb6..6cecb39a 100644 --- a/coq/ProbTheory/SimpleExpectation.v +++ b/rocq/ProbTheory/SimpleExpectation.v @@ -2443,8 +2443,8 @@ Section SimpleConditionalExpectation. unfold pre_list_collection in H2. assert (n'bound:(n' < length (map event_pre l))%nat). { - destruct (lt_dec n' (length (map event_pre l))%nat); trivial. - apply not_lt in n0. + destruct (Compare_dec.lt_dec n' (length (map event_pre l))%nat); trivial. + apply Nat.le_ngt in n0. unfold pre_event in H2. rewrite (nth_overflow (map event_pre l) pre_event_none n0) in H2. elim (Cn'ne H2). diff --git a/coq/ProbTheory/VectorConditionalExpectation.v b/rocq/ProbTheory/VectorConditionalExpectation.v similarity index 100% rename from coq/ProbTheory/VectorConditionalExpectation.v rename to rocq/ProbTheory/VectorConditionalExpectation.v diff --git a/coq/ProbTheory/VectorRandomVariable.v b/rocq/ProbTheory/VectorRandomVariable.v similarity index 99% rename from coq/ProbTheory/VectorRandomVariable.v rename to rocq/ProbTheory/VectorRandomVariable.v index beff87bb..e27896a3 100644 --- a/coq/ProbTheory/VectorRandomVariable.v +++ b/rocq/ProbTheory/VectorRandomVariable.v @@ -354,7 +354,7 @@ Section vector_ops. Qed. Lemma vecrvsum_rvsum {n} (f : Ts -> vector R n) : - rv_eq (vecrvsum f) (rvsum (fun i x => match lt_dec i n with + rv_eq (vecrvsum f) (rvsum (fun i x => match Compare_dec.lt_dec i n with | left pf => vector_nth i pf (f x) | right _ => 0%R end) @@ -365,8 +365,8 @@ Section vector_ops. destruct (f a); simpl. subst. rewrite list_sum_sum_n. - apply (@Hierarchy.sum_n_ext Hierarchy.R_AbelianGroup); intros. - destruct (lt_dec n (length x)). + apply (@Hierarchy.sum_n_ext Hierarchy.R_AbelianMonoid); intros. + destruct (Compare_dec.lt_dec n (length x)). - unfold vector_nth. match goal with [|- context [proj1_sig ?x]] => destruct x @@ -1800,7 +1800,7 @@ Section real_pullback. rewrite <- vector_Forall2_nth_iff. intros. rewrite vector_nth_const. - destruct (lt_dec i n). + destruct (Compare_dec.lt_dec i n). - rewrite vector_nth_add_to_end_prefix with (pf2 := l). now rewrite vector_nth_const. - assert (i = n)%nat by lia. @@ -1911,7 +1911,7 @@ Section real_pullback. rewrite <- vector_Forall2_nth_iff. intros. rewrite vector_nth_const. - destruct (lt_dec i n). + destruct (Compare_dec.lt_dec i n). - rewrite vector_nth_add_to_end_prefix with (pf2 := l). now rewrite vector_nth_const. - assert (i = n)%nat by lia. @@ -2012,7 +2012,7 @@ Section real_pullback. exists (vector_add_to_end x0 (vector_const pre_Ω n)). split; intros. -- rewrite vector_nth_map, vector_nth_const. - destruct (lt_dec i n). + destruct (Compare_dec.lt_dec i n). ++ rewrite vector_nth_add_to_end_prefix with (pf2 := l). rewrite vector_nth_const. apply sa_all. @@ -2021,7 +2021,7 @@ Section real_pullback. now rewrite vector_nth_add_to_end_suffix. -- intro z. split; intros. - ++ destruct (lt_dec i n). + ++ destruct (Compare_dec.lt_dec i n). ** rewrite vector_nth_add_to_end_prefix with (pf2 := l). rewrite vector_nth_const. apply I. @@ -2145,7 +2145,7 @@ Section almost. apply all_almost; intros ??. now apply vector_Forall2_nth_iff. + intros. - apply lt_dec. + apply Compare_dec.lt_dec. + unfold vecrvnth. intros. now repeat rewrite (vector_nth_ext _ _ pf2 pf1). diff --git a/coq/ProbTheory/lintp_wrapper.v b/rocq/ProbTheory/lintp_wrapper.v similarity index 100% rename from coq/ProbTheory/lintp_wrapper.v rename to rocq/ProbTheory/lintp_wrapper.v diff --git a/coq/QLearn/Bellman.v b/rocq/QLearn/Bellman.v similarity index 100% rename from coq/QLearn/Bellman.v rename to rocq/QLearn/Bellman.v diff --git a/coq/QLearn/Dvoretzky.v b/rocq/QLearn/Dvoretzky.v similarity index 98% rename from coq/QLearn/Dvoretzky.v rename to rocq/QLearn/Dvoretzky.v index ef72e92d..9a43cbcf 100644 --- a/coq/QLearn/Dvoretzky.v +++ b/rocq/QLearn/Dvoretzky.v @@ -372,7 +372,7 @@ Lemma Dvoretzky_rel (n:nat) (theta:R) (X Y : nat -> Ts -> R) part_prod_n a n m <= part_prod_n b n m. Proof. intros. - destruct (le_dec n m). + destruct (Compare_dec.le_dec n m). - pose (h := (m - n)%nat). replace (m) with (n + h)%nat by lia. apply part_prod_n_le_h. @@ -604,7 +604,7 @@ Qed. (sum_n_m (fun n0 : nat => -2 * a n0) (S n0) (n + S n0))). { intros. - rewrite <- sum_split; trivial; lia. + rewrite <- (@sum_split R_AbelianGroup); trivial; lia. } unfold plus in H13; simpl in H13. exists (-2*x1 - sum_n_m (fun n0 : nat => -2 * a n0) 0 n0 ). @@ -994,15 +994,15 @@ Section Derman_Sacks. - rewrite is_lim_seq_incr_n with (N := N). apply is_lim_seq_ext with (u := fun n => - (sum_n_m (fun j : nat => (delta j - c j) / B j) 0 (n + N)) - - (sum_n_m (fun j : nat => (delta j - c j) / B j) 0 (N-1))). + (@sum_n_m R_AbelianGroup (fun j : nat => (delta j - c j) / B j) 0 (n + N)) - + (@sum_n_m R_AbelianGroup (fun j : nat => (delta j - c j) / B j) 0 (N-1))). intros. - rewrite sum_split with (m := (N-1)%nat); try lia. + rewrite (@sum_split R_AbelianGroup) with (m := (N-1)%nat); try lia. unfold plus; simpl; ring_simplify. replace (S (N-1))%nat with N by lia; trivial. apply is_lim_seq_minus with (l1 := m_infty) - (l2 := sum_n_m (fun j : nat => (delta j - c j) / B j) 0 (N - 1)). + (l2 := @sum_n_m R_AbelianGroup (fun j : nat => (delta j - c j) / B j) 0 (N - 1)). + now rewrite is_lim_seq_incr_n with (N := N) in H13. + apply is_lim_seq_const. + now unfold is_Rbar_minus, is_Rbar_plus; simpl. @@ -2829,7 +2829,7 @@ Theorem Dvoretzky_DS_scale_prop (m n i:nat) pf1 pf2 ts : vector_nth i pf1 (DS_Xn_v X0 T Y m ts) = vector_nth i pf2 (DS_Xn_v X0 T Y n ts). Proof. - destruct (le_dec n m). + destruct (Compare_dec.le_dec n m). - now apply DS_Xn_v_same_prefix_le_helper. - symmetry. apply DS_Xn_v_same_prefix_le_helper. @@ -3776,15 +3776,14 @@ Lemma Dvoretzky_converge_W_alpha_beta_uniform (W w α β: nat -> Ts -> R) almostR2 prts Rbar_le (ConditionalExpectation prts (filt_sub n) (rvsqr (w n))) (const (Rsqr C)))) -> - almost prts (fun ω : Ts => is_lim_seq (sum_n(fun n : nat => α n ω)) p_infty) -> - almost prts (fun ω : Ts => is_lim_seq (sum_n (fun n : nat => β n ω)) p_infty) -> + almost prts (fun ω : Ts => is_lim_seq (@sum_n R_AbelianGroup (fun n : nat => α n ω)) p_infty) -> + almost prts (fun ω : Ts => is_lim_seq (@sum_n R_AbelianGroup (fun n : nat => β n ω)) p_infty) -> almost prts (fun ω => ex_series (fun n => Rsqr (β n ω))) -> - (exists epsilon : posreal, eventually (fun n => almostR2 prts Rbar_lt (fun ω => Lim_seq (sum_n (fun nn => rvsqr (β (nn+n)%nat) ω))) (const epsilon))) -> + (exists epsilon : posreal, eventually (fun n => almostR2 prts Rbar_lt (fun ω => Lim_seq (@sum_n R_AbelianGroup (fun nn => rvsqr (β (nn+n)%nat) ω))) (const epsilon))) -> (forall n, rv_eq (W (S n)) (rvplus (rvmult (rvminus (const 1) (α n)) (W n)) (rvmult (w n) (β n)))) -> almost _ (fun ω => is_lim_seq (fun n => W n ω) (Finite 0)). Proof. intros condexpw condexpw2 alpha_inf beta_inf beta_sqr [ϵ beta_bounded] (* W0 *) Wrel. - assert (svy1b: forall n : nat, IsFiniteExpectation prts (rvsqr (β n))). { intros. @@ -3792,7 +3791,10 @@ Proof. } eapply (@Dvoretzky_converge_W_alpha_beta_isf_seq_sum W w α β F isfilt filt_sub adaptZ adapt_alpha adapt_beta rvw); eauto. - + change (is_finite + (Lim_seq + (@sum_n R_AbelianGroup + (fun n : nat => @FiniteExpectation Ts dom prts (@rvsqr Ts (β n)) (svy1b n))))). generalize (sum_expectation prts (fun n => rvsqr (β n))); intros HH. assert (rv2 : forall n, RandomVariable dom borel_sa (rvsqr (β n))). { @@ -3812,7 +3814,7 @@ Proof. unfold A3'. unfold Series. apply Rbar_real_rv. - cut (RandomVariable dom Rbar_borel_sa (fun omega : Ts => ELim_seq (sum_n (fun n : nat => (β n omega)²)))). + cut (RandomVariable dom Rbar_borel_sa (fun omega : Ts => ELim_seq (@sum_n R_AbelianGroup (fun n : nat => (β n omega)²)))). { apply RandomVariable_proper; try reflexivity; intros ?. now rewrite <- Elim_seq_fin. @@ -3878,7 +3880,7 @@ Proof. specialize (betaN (S N)). cut_to betaN; try lia. - pose (btail := (rvplus (fun ω => sum_n (fun nn : nat => rvsqr (β nn) ω) N) + pose (btail := (rvplus (fun ω => @sum_n R_AbelianGroup (fun nn : nat => rvsqr (β nn) ω) N) (const (pos ϵ)))). assert (btail_rv : RandomVariable dom Rbar_borel_sa btail). @@ -3892,7 +3894,7 @@ Proof. apply IsFiniteExpectation_Rbar. apply Rbar_finexp_finexp. { - cut (RandomVariable dom Rbar_borel_sa (fun omega : Ts => ELim_seq (sum_n (fun n : nat => (β n omega)²)))). + cut (RandomVariable dom Rbar_borel_sa (fun omega : Ts => ELim_seq (@sum_n R_AbelianGroup (fun n : nat => (β n omega)²)))). { apply RandomVariable_proper; try reflexivity; intros ?. now rewrite <- Elim_seq_fin. @@ -3901,7 +3903,7 @@ Proof. } apply (Rbar_IsFiniteExpectation_nnf_bounded_almost _ _ btail). - intros ?. - generalize (Lim_seq_le (fun _ => 0) (sum_n (fun n : nat => (β n x)²))) + generalize (Lim_seq_le (fun _ => 0) (@sum_n R_AbelianGroup (fun n : nat => (β n x)²))) ; intros HHH. cut_to HHH. + rewrite Lim_seq_const in HHH. @@ -3916,13 +3918,13 @@ Proof. apply almost_impl. apply all_almost; intros ω bsqr_ex bsqr_bound. rewrite <- (Lim_seq_incr_n _ (S N)). - assert (eqq:Lim_seq (fun n : nat => sum_n (fun n0 : nat => (β n0 ω)²) (n + S N)) = + assert (eqq:Lim_seq (fun n : nat => @sum_n R_AbelianGroup (fun n0 : nat => (β n0 ω)²) (n + S N)) = - (Lim_seq (fun n => sum_n (fun nn : nat => rvsqr (β nn) ω) N + - (sum_n (fun nn : nat => rvsqr (β (nn + S N)%nat) ω) n)))). + (Lim_seq (fun n => @sum_n R_AbelianGroup (fun nn : nat => rvsqr (β nn) ω) N + + (@sum_n R_AbelianGroup (fun nn : nat => rvsqr (β (nn + S N)%nat) ω) n)))). { apply Lim_seq_ext; intros n. - generalize (sum_split (fun n0 : nat => (β n0 ω)²) 0 ((n + S N)%nat) N) + generalize (@sum_split R_AbelianGroup (fun n0 : nat => (β n0 ω)²) 0 ((n + S N)%nat) N) ; intros HHH. cut_to HHH; try lia. unfold sum_n. @@ -3931,12 +3933,19 @@ Proof. f_equal. now rewrite sum_shift. } + change ( Rbar_le (Lim_seq (fun n : nat => @sum_n R_AbelianGroup (fun n0 : nat => (β n0 ω)²) (n + S N))) + (rvplus (fun ω0 : Ts => @sum_n R_AbelianGroup (fun nn : nat => rvsqr (β nn) ω0) N) (const ϵ) ω)). rewrite eqq. rewrite Lim_seq_plus. + unfold rvplus. rewrite Lim_seq_const. - replace (Finite (sum_n (fun nn : nat => rvsqr (β nn) ω) N + const (pos ϵ) ω)) with - (Rbar_plus (sum_n (fun nn : nat => rvsqr (β nn) ω) N) (const (pos ϵ) ω)) by reflexivity. + replace (Finite + (Rplus + (@sum_n (AbelianGroup.AbelianMonoid R_AbelianGroup) + (fun nn : nat => @rvsqr Ts (β nn) ω) N) + (pos (@const posreal Ts ϵ ω)))) + with + (Rbar_plus (@sum_n R_AbelianGroup (fun nn : nat => rvsqr (β nn) ω) N) (const (pos ϵ) ω)) by reflexivity. apply Rbar_plus_le_compat. { apply Rbar_le_refl. } now apply Rbar_lt_le. @@ -4139,16 +4148,18 @@ Proof. + destruct n; [lia |]. rewrite <- (Lim_seq_incr_n (sum_n (fun n0 : nat => rvsqr (β n0) ω)) (S n)). apply Lim_seq_le; intros. - generalize (sum_split (fun n0 : nat => (β n0 ω)²) 0 ((n0 + S n)%nat) n) + generalize (@sum_split R_AbelianGroup (fun n0 : nat => (β n0 ω)²) 0 ((n0 + S n)%nat) n) ; intros HHH. - cut_to HHH; try lia. - unfold sum_n, rvsqr. - rewrite HHH. - unfold plus, rvsqr; simpl. - rewrite sum_shift. - cut (0 <= sum_n_m (fun n1 : nat => (β n1 ω)²) 0 n); try lra. - apply sum_n_m_pos; intros. - apply Rle_0_sqr. + cut_to HHH; try lia. + unfold sum_n, rvsqr. + etransitivity; [| right; symmetry; apply HHH]. + unfold plus, rvsqr; simpl. + rewrite sum_shift. + change (@sum_n_m R_AbelianGroup (fun nn : nat => (β (nn + S n)%nat ω)²) 0 n0 <= + @sum_n_m R_AbelianGroup (fun n1 : nat => (β n1 ω)²) 0 n + @sum_n_m R_AbelianGroup (fun n1 : nat => (β (n1 + S n)%nat ω)²) 0 n0). + cut (0 <= @sum_n_m R_AbelianGroup (fun n1 : nat => (β n1 ω)²) 0 n); try lra. + apply sum_n_m_pos; intros. + apply Rle_0_sqr. + simpl. apply Rmax_r. Qed. diff --git a/coq/QLearn/Tsitsiklis.v b/rocq/QLearn/Tsitsiklis.v similarity index 99% rename from coq/QLearn/Tsitsiklis.v rename to rocq/QLearn/Tsitsiklis.v index 3a235ba3..0cc2eca5 100644 --- a/coq/QLearn/Tsitsiklis.v +++ b/rocq/QLearn/Tsitsiklis.v @@ -202,7 +202,7 @@ Proof. } pose (tau_coll k t j := - match (le_dec j t) with + match (Compare_dec.le_dec j t) with | left pf => event_lt (rv := rvB j t pf) (F t) (B j) (INR k) | _ => Ω end). @@ -516,7 +516,7 @@ Proof. } pose (tau_coll k t j := - match (le_dec j t) with + match (Compare_dec.le_dec j t) with | left pf => event_lt (rv := rvB j t pf) (F t) (B j) (INR k) | _ => Ω end). @@ -3991,7 +3991,7 @@ Qed. { apply almost_bounded_forall. intros. - - apply lt_dec. + - apply Compare_dec.lt_dec. - intros. apply is_lim_seq_ext with (u := (fun k : nat => WW i pf1 k 0%nat x)); trivial. intros. @@ -4128,7 +4128,7 @@ Qed. destruct (H21 H22). exists ((1 + ε) * (G x0 x)). intros. - destruct (le_dec t x0). + destruct (Compare_dec.le_dec t x0). - apply Rle_trans with (r2 := M x0 x). + unfold M. apply Rmax_seq_map_monotone. @@ -5001,7 +5001,7 @@ Qed. { apply almost_bounded_forall. intros. - - apply lt_dec. + - apply Compare_dec.lt_dec. - intros. apply is_lim_seq_ext with (u := (fun k : nat => WW i pf1 k 0%nat x)); trivial. intros. @@ -5145,7 +5145,7 @@ Qed. destruct (H21 H22). exists ((1 + ε) * (G x0 x)). intros. - destruct (le_dec t x0). + destruct (Compare_dec.le_dec t x0). - apply Rle_trans with (r2 := M x0 x). + unfold M. apply Rmax_seq_map_monotone. @@ -6022,7 +6022,7 @@ Qed. { apply almost_bounded_forall. intros. - - apply lt_dec. + - apply Compare_dec.lt_dec. - intros. apply is_lim_seq_ext with (u := (fun k : nat => WW i pf1 k 0%nat x)); trivial. intros. @@ -6166,7 +6166,7 @@ Qed. destruct (H21 H22). exists ((1 + ε) * (G x0 x)). intros. - destruct (le_dec t x0). + destruct (Compare_dec.le_dec t x0). - apply Rle_trans with (r2 := M x0 x). + unfold M. apply Rmax_seq_map_monotone. @@ -6767,7 +6767,7 @@ Qed. { apply almost_bounded_forall. intros. - - apply le_dec. + - apply Compare_dec.le_dec. - intros. rewrite (digit_pf_irrel _ _ pf2 pf1). apply H14. @@ -7378,7 +7378,7 @@ Qed. { apply almost_bounded_forall. intros. - - apply le_dec. + - apply Compare_dec.le_dec. - intros. rewrite (digit_pf_irrel _ _ pf2 pf1). apply H14. @@ -7717,7 +7717,7 @@ Qed. apply H11. + apply Rle_ge, Rvector_max_abs_nonneg. - intros. - apply lt_dec. + apply Compare_dec.lt_dec. - intros i pf1 pf2 x. apply is_lim_seq_ext. intros. @@ -8001,7 +8001,7 @@ Qed. apply H11. + apply Rle_ge, Rvector_max_abs_nonneg. - intros. - apply lt_dec. + apply Compare_dec.lt_dec. - intros i pf1 pf2 x. apply is_lim_seq_ext. intros. @@ -8972,7 +8972,7 @@ Qed. * apply H14. lia. + intros. - apply le_dec. + apply Compare_dec.le_dec. + intros. apply H14. - revert alpha_betaprop; apply almost_impl. @@ -9038,7 +9038,7 @@ Qed. * apply H0. lia. + intros. - apply le_dec. + apply Compare_dec.le_dec. + intros. apply H14. - revert H1; apply almost_impl. diff --git a/coq/QLearn/infprod.v b/rocq/QLearn/infprod.v similarity index 97% rename from coq/QLearn/infprod.v rename to rocq/QLearn/infprod.v index 30431b92..6b82ec96 100644 --- a/coq/QLearn/infprod.v +++ b/rocq/QLearn/infprod.v @@ -1,4 +1,4 @@ -Require Import Reals Sums Lra Lia. +Require Import ZArith Reals Sums Lra Lia. (* Require Import Coquelicot.Hierarchy Coquelicot.Series Coquelicot.Lim_seq Coquelicot.Rbar.*) Require Import Coquelicot.Coquelicot. Require Import LibUtils. @@ -493,7 +493,7 @@ Proof. replace (S n2 - n1)%nat with ((S m - n1) + (S n2 - S m))%nat by lia. rewrite seq_plus. rewrite List.fold_right_app. - rewrite fold_right_plus_acc. + rewrite (@fold_right_plus_acc G). now replace (n1 + (S m - n1))%nat with (S m) by lia. Qed. @@ -534,7 +534,7 @@ Qed. now simpl. + intros. unfold sum_n. - rewrite sum_split with (m := (nk-1)%nat); try lia. + rewrite (@sum_split R_AbelianGroup) with (m := (nk-1)%nat); try lia. apply Rplus_eq_compat_l. replace (S (nk - 1)) with (nk) by lia. apply sum_n_m_shift. @@ -562,7 +562,8 @@ Qed. cut (ex_lim_seq (fun n : nat => sum_n_m α 0 (nk + S n) - sum_n_m α 0 nk)). { apply ex_lim_seq_ext; intros. - rewrite (sum_split_plus α 0 nk (S n)); try lia. + change ((@sum_n_m R_AbelianGroup α 0 (nk + S n) - @sum_n_m R_AbelianGroup α 0 nk) = @sum_n_m R_AbelianGroup α (S nk) (n + S nk)). + rewrite (@sum_split_plus R_AbelianGroup α 0 nk (S n) ltac:(lia) ltac:(lia)). unfold plus; simpl. field_simplify. f_equal. @@ -885,7 +886,7 @@ Proof. destruct H0 as [k H0]; destruct H0. exists k. split; trivial; intros. - destruct (lt_dec m n). + destruct (Compare_dec.lt_dec m n). + remember (n - S m)%nat as nm. replace (n) with (S m + nm)%nat; [|lia]. rewrite initial_seg_prod_n; trivial. @@ -948,7 +949,7 @@ Proof. specialize (IHk H1). apply Rle_trans with (r2 := part_prod_n (pos_sq_fun F) (m + k) n); trivial. replace (m + S k)%nat with (S (m+k)%nat) by lia. - destruct (le_gt_dec (S (m+k)) n). + destruct (Compare_dec.le_gt_dec (S (m+k)) n). + apply max_bounded1_pre_le; trivial. intros; apply pos_sq_bounded1; trivial. + rewrite (part_prod_n_1 (pos_sq_fun F) (S (m + k)%nat)) ; [|lia]. @@ -1110,7 +1111,7 @@ Section Dvoretsky. Theorem Dvoretzky4_0 (F: nat -> posreal) (sigma V : nat -> R) : (forall (n:nat), V (S n) <= (F n) * (V n) + (sigma n)) -> (forall (n:nat), - V (S n) <= sum_n (fun k => (sigma k)*(part_prod_n F (S k) n)) n + + V (S n) <= @sum_n R_AbelianGroup (fun k => (sigma k)*(part_prod_n F (S k) n)) n + (V 0%nat)*(part_prod_n F 0 n)). Proof. intros. @@ -1145,8 +1146,8 @@ Qed. Lemma sum_bound_prod_A (F : nat -> posreal) (sigma : nat -> R) (A : R) (n m:nat) : (forall r s, part_prod_n (pos_sq_fun F) r s <= A) -> - sum_n_m (fun k => (Rsqr (sigma k))*(part_prod_n (pos_sq_fun F) (S k) n)) (S m) n <= - (sum_n_m (fun k => Rsqr (sigma k)) (S m) n) * A. + @sum_n_m R_AbelianGroup (fun k => (Rsqr (sigma k))*(part_prod_n (pos_sq_fun F) (S k) n)) (S m) n <= + (@sum_n_m R_AbelianGroup (fun k => Rsqr (sigma k)) (S m) n) * A. Proof. intros. rewrite <- sum_n_m_mult_r with (a := A). @@ -1160,8 +1161,8 @@ Qed. Lemma sum_bound3_max (F : nat -> posreal) (sigma : nat -> R) (n m:nat) : (S m <= n)%nat -> - sum_n (fun k => (Rsqr (sigma k))*(part_prod_n (pos_sq_fun F) (S k) n)) m <= - (sum_n (fun k => (Rsqr (sigma k))) m) * (max_prod_fun (pos_sq_fun F) (S m) n). + @sum_n R_AbelianGroup (fun k => (Rsqr (sigma k))*(part_prod_n (pos_sq_fun F) (S k) n)) m <= + (@sum_n R_AbelianGroup (fun k => (Rsqr (sigma k))) m) * (max_prod_fun (pos_sq_fun F) (S m) n). Proof. intros. rewrite <- sum_n_mult_r with (a := (max_prod_fun (pos_sq_fun F) (S m) n)). @@ -1176,8 +1177,8 @@ Theorem Dvoretzky4_8_5 (F : nat -> posreal) (sigma V: nat -> R) (n m:nat) (A:R): (forall (n:nat), Rsqr (V (S n)) <= (pos_sq_fun F) n * Rsqr (V n) + Rsqr (sigma n)) -> (m Rsqr (V (S n)) <= - ( sum_n_m (fun k => Rsqr (sigma k)) (S m) n) * A + - (Rsqr (V 0%nat) + sum_n (fun k => (Rsqr (sigma k))) m) * + ( @sum_n_m R_AbelianGroup (fun k => Rsqr (sigma k)) (S m) n) * A + + (Rsqr (V 0%nat) + @sum_n R_AbelianGroup (fun k => (Rsqr (sigma k))) m) * (max_prod_fun (pos_sq_fun F) (S m) n). Proof. intros F1 Vsqle mn. @@ -1185,6 +1186,7 @@ Proof. intros. specialize (H Vsqle n). unfold sum_n in H. + rewrite (sum_split _ _ _ m) in H; trivial; [|lia]. generalize (sum_bound_prod_A F sigma A n m F1); intros. generalize (max_prod_le (pos_sq_fun F) 0 (S m) n); intros. @@ -1204,8 +1206,8 @@ Lemma sum_bound_prod_A_sigma1 (F : nat -> posreal) (sigma : nat -> R) (A : R) (n m:nat) : (forall r s, part_prod_n (pos_sq_fun F) r s <= A) -> (forall n, 0 <= sigma n) -> - sum_n_m (fun k => (sigma k)*(part_prod_n (pos_sq_fun F) (S k) n)) (S m) n <= - (sum_n_m sigma (S m) n) * A. + @sum_n_m R_AbelianGroup (fun k => (sigma k)*(part_prod_n (pos_sq_fun F) (S k) n)) (S m) n <= + (@sum_n_m R_AbelianGroup sigma (S m) n) * A. Proof. intros. rewrite <- sum_n_m_mult_r with (a := A). @@ -1218,8 +1220,8 @@ Qed. Lemma sum_bound3_max_sigma1 (F : nat -> posreal) (sigma : nat -> R) (n m:nat) : (S m <= n)%nat -> (forall n, 0 <= sigma n) -> - sum_n (fun k => (sigma k)*(part_prod_n (pos_sq_fun F) (S k) n)) m <= - (sum_n sigma m) * (max_prod_fun (pos_sq_fun F) (S m) n). + @sum_n R_AbelianGroup (fun k => (sigma k)*(part_prod_n (pos_sq_fun F) (S k) n)) m <= + (@sum_n R_AbelianGroup sigma m) * (max_prod_fun (pos_sq_fun F) (S m) n). Proof. intros. rewrite <- sum_n_mult_r with (a := (max_prod_fun (pos_sq_fun F) (S m) n)). @@ -1238,8 +1240,8 @@ Theorem Dvoretzky4_8_5_V1 (F : nat -> posreal) (sigma V: nat -> R) (n m:nat) (A: (forall (n:nat), 0 <= sigma n) -> (m V (S n) <= - (sum_n_m sigma (S m) n) * A + - (V 0%nat + sum_n sigma m) * + (@sum_n_m R_AbelianGroup sigma (S m) n) * A + + (V 0%nat + @sum_n R_AbelianGroup sigma m) * (max_prod_fun (pos_sq_fun F) (S m) n). Proof. intros F1 Vle Vpos sigma_pos mn. @@ -1269,12 +1271,12 @@ Theorem Dvoretzky4_8_5_1 (F : nat -> posreal) (sigma V: nat -> R) (n m:nat) (A s is_series (fun n => Rsqr (sigma n)) sigmasum -> (m Rsqr (V (S n)) <= - (sum_n_m (fun k => Rsqr (sigma k)) (S m) n) * A + + (@sum_n_m R_AbelianGroup (fun k => Rsqr (sigma k)) (S m) n) * A + (Rsqr (V 0%nat) + sigmasum) * (max_prod_fun (pos_sq_fun F) (S m) n). Proof. intros. generalize (Dvoretzky4_8_5 F sigma V n m A H H0 H2); intros. - assert (sum_n (fun k : nat => (sigma k)²) m <= sigmasum). + assert (@sum_n R_AbelianGroup (fun k : nat => (sigma k)²) m <= sigmasum). - assert (H1' := H1). apply is_series_unique in H1. assert (ex_series (fun k : nat => (sigma k)²)). @@ -1300,12 +1302,12 @@ Theorem Dvoretzky4_8_5_1_V1 (F : nat -> posreal) (sigma V: nat -> R) (n m:nat) ( is_series sigma sigmasum -> (m V (S n) <= - (sum_n_m sigma (S m) n) * A + + (@sum_n_m R_AbelianGroup sigma (S m) n) * A + (V 0%nat + sigmasum) * (max_prod_fun (pos_sq_fun F) (S m) n). Proof. intros. generalize (Dvoretzky4_8_5_V1 F sigma V n m A H H0 H2 H1 H4); intros. - assert (sum_n sigma m <= sigmasum). + assert (@sum_n R_AbelianGroup sigma m <= sigmasum). - assert (H3' := H3). apply is_series_unique in H3. assert (ex_series sigma). diff --git a/coq/QLearn/jaakkola_vector.v b/rocq/QLearn/jaakkola_vector.v similarity index 99% rename from coq/QLearn/jaakkola_vector.v rename to rocq/QLearn/jaakkola_vector.v index 44c3d33f..13faa82a 100644 --- a/coq/QLearn/jaakkola_vector.v +++ b/rocq/QLearn/jaakkola_vector.v @@ -2879,7 +2879,7 @@ Section jaakola_vector2. apply all_almost; intros ??. now apply lim_seq_maxabs0_b. - intros. - apply lt_dec. + apply Compare_dec.lt_dec. - intros. revert H0. apply is_lim_seq_ext. @@ -3030,16 +3030,16 @@ Section jaakola_vector2. repeat rewrite vector_map_create. repeat rewrite vector_nth_create. repeat rewrite vector_nth_map. - assert (pf3 : (i < (S N))%nat). + assert (pf3 : (0 + i < 0 + (S N))%nat). { lia. } simpl. repeat match goal with - [|- context [@vector_nth ?RR ?nn i (plus_lt_compat_l ?pi ?pn ?pc ?pp) ?v]] => - replace (@vector_nth RR nn i (plus_lt_compat_l pi pn pc pp) v) - with (@vector_nth RR nn i pf3 v) - by apply vector_nth_ext + [|- context [@vector_nth R (S N) i (@proj1 ?h1 ?h2 (Nat.add_lt_mono_l ?pi ?pn ?pc) ?pp) ?v]] => + replace (@vector_nth R (S N) i (@proj1 h1 h2 (Nat.add_lt_mono_l pi pn pc) pp) v) + with (@vector_nth R (S N) i pf3 v) + by now apply vector_nth_ext end. lra. } @@ -3274,7 +3274,7 @@ Section jaakola_vector2. apply H18. } apply almost_bounded_forall; intros. - - apply lt_dec. + - apply Compare_dec.lt_dec. - revert H19. apply is_lim_seq_ext. intros. @@ -4458,7 +4458,7 @@ Section jaakola_vector2. apply bounded_forall_almost with (n := i) (pf := pf) in H0. - apply H0. - intros. - apply lt_dec. + apply Compare_dec.lt_dec. - intros. erewrite vector_nth_ext; try apply H11. } @@ -4471,7 +4471,7 @@ Section jaakola_vector2. apply eventually_impl. apply all_eventually; intros. apply almost_bounded_forall. - + intros; apply lt_dec. + + intros; apply Compare_dec.lt_dec. + intros. erewrite vector_nth_ext; try apply H12. + apply H11. @@ -4486,7 +4486,7 @@ Section jaakola_vector2. apply bounded_forall_almost with (n := i) (pf := pf) in H1. - apply H1. - intros. - apply lt_dec. + apply Compare_dec.lt_dec. - intros. erewrite vector_nth_ext; try apply H11. } @@ -4499,7 +4499,7 @@ Section jaakola_vector2. apply eventually_impl. apply all_eventually; intros. apply almost_bounded_forall. - + intros; apply lt_dec. + + intros; apply Compare_dec.lt_dec. + intros. erewrite vector_nth_ext; try apply H12. + apply H11. @@ -5192,7 +5192,7 @@ Section jaakola_vector2. apply bounded_forall_almost with (n := i) (pf := pf) in H0. - apply H0. - intros. - apply lt_dec. + apply Compare_dec.lt_dec. - intros. erewrite vector_nth_ext; try apply H10. } @@ -5205,7 +5205,7 @@ Section jaakola_vector2. apply eventually_impl. apply all_eventually; intros. apply almost_bounded_forall. - + intros; apply lt_dec. + + intros; apply Compare_dec.lt_dec. + intros. erewrite vector_nth_ext; try apply H11. + apply H10. @@ -5220,7 +5220,7 @@ Section jaakola_vector2. apply bounded_forall_almost with (n := i) (pf := pf) in H1. - apply H1. - intros. - apply lt_dec. + apply Compare_dec.lt_dec. - intros. erewrite vector_nth_ext; try apply H10. } @@ -5233,7 +5233,7 @@ Section jaakola_vector2. apply eventually_impl. apply all_eventually; intros. apply almost_bounded_forall. - + intros; apply lt_dec. + + intros; apply Compare_dec.lt_dec. + intros. erewrite vector_nth_ext; try apply H11. + apply H10. @@ -5395,7 +5395,7 @@ Section jaakola_vector2. unfold rvscale. now rewrite Rmult_comm. - intros. - apply lt_dec. + apply Compare_dec.lt_dec. - intros. assert (FiniteConditionalExpectation prts sub (fun ω : Ts => vector_nth i pf1 (pos_Rvector_mult (f ω) W)) x = FiniteConditionalExpectation prts sub (fun ω : Ts => vector_nth i pf2 (pos_Rvector_mult (f ω) W)) x). @@ -5676,7 +5676,7 @@ Proof. apply Rle_abs. + apply almost_forall; intros. apply almost_bounded_forall; intros. - * apply lt_dec. + * apply Compare_dec.lt_dec. * eapply Rbar_le_trans; cycle 1. apply H3. apply slln.eq_Rbar_le. @@ -6221,7 +6221,7 @@ Section qlearn. intros ?. now apply sum_n_ext. - apply bounded_forall_almost. - + intros; apply lt_dec. + + intros; apply Compare_dec.lt_dec. + intros ????. apply ex_series_ext. intros. diff --git a/coq/QLearn/lim_add.v b/rocq/QLearn/lim_add.v similarity index 100% rename from coq/QLearn/lim_add.v rename to rocq/QLearn/lim_add.v diff --git a/coq/QLearn/qlearn.v b/rocq/QLearn/qlearn.v similarity index 99% rename from coq/QLearn/qlearn.v rename to rocq/QLearn/qlearn.v index 4d0abada..8d158b4c 100644 --- a/coq/QLearn/qlearn.v +++ b/rocq/QLearn/qlearn.v @@ -115,7 +115,7 @@ Section rv_expressible. split. + intros. generalize (H1 n); intros. - destruct (le_dec x0 n). + destruct (Compare_dec.le_dec x0 n). * specialize (H2 l). rewrite Rabs_lt_between in H2. lra. @@ -174,7 +174,7 @@ Section rv_expressible. specialize (H0 M). destruct H0. simpl. - destruct (le_dec x n). + destruct (Compare_dec.le_dec x n). * now apply H0. * assert (n < x)%nat by lia. generalize (increasing_seq f H n (x - n)%nat); intros. @@ -324,7 +324,7 @@ End rv_expressible. generalize (scal r y2) as d. intros. unfold minus. - rewrite opp_plus. + rewrite (@opp_plus X). rewrite plus_assoc. rewrite plus_assoc. f_equal. @@ -577,7 +577,7 @@ End rv_expressible. (forall n, 0 <= α n <= 1) -> (forall n, 0 <= (1-gamma)* α (n+k)%nat < 1) -> is_lim_seq α 0 -> - is_lim_seq (sum_n α) p_infty -> + is_lim_seq (@sum_n R_AbelianGroup α) p_infty -> is_lim_seq (fun n => prod_f_R0 (fun m => g_alpha gamma (α (m + k)%nat)) n) 0. Proof. intros. @@ -597,7 +597,7 @@ End rv_expressible. reflexivity. - unfold l1_divergent. apply is_lim_seq_ext with - (u := (fun m => (1-gamma) * (sum_n (fun n => α (n + k)%nat) m))). + (u := (fun m => (1-gamma) * (@sum_n R_AbelianGroup (fun n => α (n + k)%nat) m))). + intros. generalize (sum_n_mult_l (1-gamma) (fun n => α (n + k)%nat) n); intros. unfold Hierarchy.mult in H6; simpl in H6. @@ -613,9 +613,9 @@ End rv_expressible. lia. -- assert (k > 0)%nat by lia. apply is_lim_seq_ext with - (u := fun m => minus (sum_n α (m + k)%nat) (sum_n α (k-1)%nat)). + (u := fun m => minus (@sum_n R_AbelianGroup α (m + k)%nat) (@sum_n R_AbelianGroup α (k-1)%nat)). ++ intros. - rewrite <- sum_n_m_sum_n; trivial; try lia. + rewrite <- (@sum_n_m_sum_n R_AbelianGroup); trivial; try lia. replace (S (k-1)%nat) with (k) by lia. apply sum_n_m_shift. ++ apply is_lim_seq_minus with @@ -698,10 +698,10 @@ End rv_expressible. unfold ex_finite_lim_seq. exists (x - (sum_n g_α N)). apply is_lim_seq_ext_loc with - (u := (fun n => minus (sum_n g_α n) (sum_n g_α N))). + (u := (fun n => minus (@sum_n R_AbelianGroup g_α n) (@sum_n R_AbelianGroup g_α N))). exists N. intros. - rewrite sum_n_m_sum_n; trivial. + rewrite (@sum_n_m_sum_n R_AbelianGroup); trivial. apply is_lim_seq_minus'; trivial. apply is_lim_seq_const. unfold ex_finite_lim_seq in H4. @@ -742,7 +742,7 @@ End rv_expressible. assert (n = 0%nat ) by lia. rewrite H2. apply H0. - - destruct (le_dec n N). + - destruct (Compare_dec.le_dec n N). + apply IHN; trivial. intros; apply H; lia. rewrite sum_Sn in H0. @@ -2230,6 +2230,7 @@ End rv_expressible. ** rewrite (scal_minus_distr_r 1). unfold minus. rewrite <- plus_assoc. + change (xstar = plus (scal 1 xstar) (plus (opp (scal (α n) xstar)) (scal (α n) xstar))). rewrite plus_opp_l. rewrite plus_zero_r. now generalize (scal_one xstar). @@ -3436,7 +3437,9 @@ End rv_expressible. now rewrite IHn. + apply is_lim_seq_const. - intros. - rewrite minus_eq_zero, norm_zero. + rewrite minus_eq_zero. + change (norm (@zero R_NormedModule) <= gamma * norm (minus x y)). + rewrite norm_zero. apply Rmult_le_pos; try lra. apply norm_ge_0. - intros; apply prod_f_R0_proper; [|trivial]. diff --git a/coq/QLearn/qlearn_aux.v b/rocq/QLearn/qlearn_aux.v similarity index 100% rename from coq/QLearn/qlearn_aux.v rename to rocq/QLearn/qlearn_aux.v diff --git a/coq/QLearn/qlearn_redux.v b/rocq/QLearn/qlearn_redux.v similarity index 99% rename from coq/QLearn/qlearn_redux.v rename to rocq/QLearn/qlearn_redux.v index 0e8ac9a2..5f1882cc 100644 --- a/coq/QLearn/qlearn_redux.v +++ b/rocq/QLearn/qlearn_redux.v @@ -1906,7 +1906,7 @@ Lemma Dvoretzky_converge_Y (C : R) (Y : nat -> Ts -> R) (alpha : nat -> Ts -> R) {adaptY : IsAdapted borel_sa Y F} (adapt_alpha : IsAdapted borel_sa alpha F) (alpha_pos:forall n x, 0 <= alpha n x) (alpha_one:forall n x, 0 <= 1 - alpha n x ) : - almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => alpha n omega)) + almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => alpha n omega)) Rbar.p_infty) -> rv_eq (Y 0%nat) (const C) -> (forall n, rv_eq (Y (S n)) (rvplus (rvmult (rvminus (const 1) (alpha n)) (Y n)) (rvscale (gamma * C) (alpha n)))) -> @@ -2042,10 +2042,10 @@ Lemma Dvoretzky_converge_Z (Z BB: nat -> Ts -> R) (alpha : nat -> Ts -> R) (fun x : Ts => ConditionalExpectation.ConditionalExpectation prts (filt_sub n) (BB n) x = Rbar.Finite (const 0 x))) -> - almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => alpha n omega)) + almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => alpha n omega)) Rbar.p_infty) -> (exists (A2 : R), - almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) -> + almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) -> (exists (sigma : R), forall n, rv_le (rvsqr (BB n)) (const (Rsqr sigma))) -> rv_eq (Z 0%nat) (const 0) -> (forall n, rv_eq (Z (S n)) (rvplus (rvmult (rvminus (const 1) (alpha n)) (Z n)) (rvmult (BB n) (alpha n)))) -> @@ -2276,6 +2276,10 @@ Proof. typeclasses eauto. } specialize (H4 H5 svy1). + change (Rbar.is_finite + (Lim_seq.Lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => FiniteExpectation prts (rvsqr (alpha n))))) +). + rewrite (Lim_seq.Lim_seq_ext _ _ H4). destruct alpha_sqr as [A2 alpha_sqr]. generalize (Dominated_convergence_almost @@ -2303,7 +2307,7 @@ Proof. unfold rvsum. left. generalize (Lim_seq_increasing_le - (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n0 : nat => rvsqr (alpha n0) x))); intros. + (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n0 : nat => rvsqr (alpha n0) x))); intros. cut_to H8. --- specialize (H8 n). generalize (Rbar.Rbar_le_lt_trans _ _ _ H8 H7); intros. @@ -2359,7 +2363,7 @@ Proof. simpl. unfold rvsum. rewrite Rabs_right. - ** generalize (Lim_seq_increasing_le (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n0 : nat => rvsqr (alpha n0) x))); intros. + ** generalize (Lim_seq_increasing_le (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n0 : nat => rvsqr (alpha n0) x))); intros. cut_to H8. --- specialize (H8 n). generalize (Rbar.Rbar_le_lt_trans _ _ _ H8 H7); intros. @@ -3406,10 +3410,10 @@ Lemma list_inter_prob_bound (l : list (event dom * R)) : (forall n, independent_sas prts (filt_sub n) (pullback_rv_sub dom cod (X (S n)) (rvX (S n)))) -> (forall n, FiniteExpectation prts (BB ∘ X n) = 0) -> - almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => alpha n omega)) + almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => alpha n omega)) Rbar.p_infty) -> (exists (A2 : R), - almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) -> + almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) -> (exists (sigma : R), forall n, rv_le (rvsqr (BB ∘ X n)) (const (Rsqr sigma))) -> rv_eq (Z 0%nat) (const 0) -> (forall n, rv_eq (Z (S n)) (rvplus (rvmult (rvminus (const 1) (alpha n)) (Z n)) (rvmult (BB ∘ X (S n)) (alpha n)))) -> @@ -3626,7 +3630,7 @@ Lemma list_inter_prob_bound (l : list (event dom * R)) : {isfe : forall n, IsFiniteExpectation prts (BB ∘ X n)} (alpha_pos:forall n x, 0 <= alpha n x) (alpha_one:forall n x, 0 <= 1 - alpha n x ) : - almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => alpha n omega)) + almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => alpha n omega)) Rbar.p_infty) -> gamma + eps < 1 -> gamma < 1 -> @@ -3641,7 +3645,7 @@ Lemma list_inter_prob_bound (l : list (event dom * R)) : (forall n, rv_eq (Z (S n)) (rvplus (rvmult (rvminus (const 1) (alpha n)) (Z n)) (rvmult (BB ∘ X (S n)) (alpha n)))) -> (exists (A2 : R), - almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) -> + almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) -> (exists (sigma : R), forall n, rv_le (rvsqr (BB ∘ X n)) (const (Rsqr sigma))) -> (forall N, (N >= tk)%nat -> forall omega, diff --git a/coq/QLearn/slln.v b/rocq/QLearn/slln.v similarity index 99% rename from coq/QLearn/slln.v rename to rocq/QLearn/slln.v index 9bf0ea5c..e3385f7a 100644 --- a/coq/QLearn/slln.v +++ b/rocq/QLearn/slln.v @@ -1,4 +1,3 @@ -Require Import Qreals. Require Import Lra Lia Reals RealAdd RandomVariableL2 Coquelicot.Coquelicot. Require Import Morphisms FiniteType List ListAdd Permutation infprod Almost NumberIso. Require Import Sums SimpleExpectation PushNeg. @@ -6,7 +5,6 @@ Require Import EquivDec. Require Import Classical. Require Import ClassicalChoice. Require Import IndefiniteDescription ClassicalDescription. -Require QArith. Require Import BorelSigmaAlgebra. Require Import utils.Utils. Require Import ConditionalExpectation. @@ -167,7 +165,7 @@ Lemma ash_6_1_2 {a x : nat -> R} {x0 : R}(ha : forall n, 0 <= a n) (hb1 : forall n, 0 < sum_f_R0 a n)(hb2 : is_lim_seq (sum_f_R0 a) p_infty) (hx : is_lim_seq x x0): is_lim_seq (fun n => (sum_f_R0 (fun j => a j * x j) n)/(sum_f_R0 a n)) x0. Proof. - pose (A := fun (n j : nat) => if (le_dec j n) then (a j)/(sum_f_R0 a n) else 0). + pose (A := fun (n j : nat) => if (Compare_dec.le_dec j n) then (a j)/(sum_f_R0 a n) else 0). assert (Apos: forall n j, 0 <= A n j). { intros. @@ -204,7 +202,7 @@ Lemma ash_6_1_2 {a x : nat -> R} {x0 : R}(ha : forall n, 0 <= a n) + now rewrite Lim_seq_const. + exists n; intros. rewrite sum_n_ext with - (b := (fun j : nat => (if le_dec j n then (((a j)*(x j)) / sum_f_R0 a n) else 0))). + (b := (fun j : nat => (if Compare_dec.le_dec j n then (((a j)*(x j)) / sum_f_R0 a n) else 0))). * rewrite <- sum_n_sum_f_clipped with (N := n); try lia. rewrite sum_n_Reals. unfold Rdiv. @@ -309,7 +307,7 @@ Proof. cut_to H3; trivial. * eapply (is_lim_seq_ext _ _ x0 _ H3). * intros. - destruct (lt_dec 0 n). + destruct (Compare_dec.lt_dec 0 n). -- specialize (hb1 (n-1)%nat). replace (S (n-1)) with (n) in hb1 by lia. lra. @@ -374,7 +372,7 @@ Lemma ash_6_1_3_strong1 {b x : nat -> R} (hb1 : forall n, 0 < b n <= b (S n)) (h (hx : ex_series x): is_lim_seq (fun n => (sum_n_m (fun j => b j * x j) 1 n)/(b n)) 0. Proof. - pose (bb := fun n => if (lt_dec 0 n) then (b n) else 0). + pose (bb := fun n => if (Compare_dec.lt_dec 0 n) then (b n) else 0). generalize (@ash_6_1_3 bb x); intros. cut_to H; trivial. - apply is_lim_seq_ext with (v := fun n => sum_n_m (fun j => b j * x j) 1 (S n) / (b (S n))) in H. @@ -1693,7 +1691,7 @@ Qed. rv_unfold. rewrite Rmax_list_app by now simpl. unfold Rmax. - rewrite plus_0_l. + rewrite Nat.add_0_l. destruct (Rle_dec (Rmax_list (map (fun n : nat => F n a) (seq 0 (S k)))) (F (S k) a)); trivial. } apply IsFiniteExpectation_case. @@ -2014,7 +2012,7 @@ Proof. unfold Sum, rvsum. rewrite sum_Sn. unfold plus. simpl. rewrite Rplus_comm. unfold Rminus; rewrite Rplus_assoc. - replace (sum_n (fun n0 : nat => X (n0+m)%nat w) j + - sum_n (fun n0 : nat => X (n0+m)%nat w) j) with 0 by lra. + replace (@sum_n R_AbelianGroup (fun n0 : nat => X (n0+m)%nat w) j + - @sum_n R_AbelianGroup (fun n0 : nat => X (n0+m)%nat w) j) with 0 by lra. rewrite Rplus_0_r. match_destr. - match_destr. @@ -2867,12 +2865,12 @@ Qed. Qed. Lemma sum_shift_diff (X : nat -> R) (m a : nat) : - sum_n X (a + S m) - sum_n X m = - sum_n (fun n0 : nat => X (n0 + S m)%nat) a. + @sum_n R_AbelianGroup X (a + S m) - @sum_n R_AbelianGroup X m = + @sum_n R_AbelianGroup (fun n0 : nat => X (n0 + S m)%nat) a. Proof. rewrite <- sum_n_m_shift. unfold sum_n. - rewrite (@sum_split _ _ _ _ m); try lia. + rewrite (@sum_split R_AbelianGroup _ _ _ m); try lia. unfold plus; simpl. lra. Qed. @@ -3997,7 +3995,7 @@ Qed. match_destr. } unfold independent_event_collection in indep. - destruct (in_dec NPeano.Nat.eq_dec 0%nat l). + destruct (in_dec Nat.eq_dec 0%nat l). - pose (ll := 1%nat :: map (fun n => match n with | 0%nat => n | S n' => S n @@ -4024,13 +4022,28 @@ Qed. etransitivity; [| etransitivity]; [| apply indep' |]. + apply ps_proper. unfold ll. - rewrite perm. - simpl. + etransitivity. + { apply list_inter_equivlist_proper. + apply Permutation_equivlist. + apply Permutation_map. + apply perm. + } + etransitivity; cycle 1. + { + apply list_inter_equivlist_proper. + apply Permutation_equivlist. + apply Permutation_map. + apply Permutation_cons; [reflexivity |]. + apply Permutation_map. + symmetry. + apply perm. + } + repeat rewrite map_cons. repeat rewrite list_inter_cons. rewrite event_inter_assoc. apply event_inter_proper. - * rewrite event_inter_comm. - apply H6. + * red; simpl. + now rewrite pre_event_inter_comm. * apply list_inter_proper. rewrite map_map. apply Forall2_map_f. @@ -4924,7 +4937,7 @@ Qed. - apply collection_take_preserves_disjoint. intros ????[??][??]; simpl in *. apply H. - apply le_antisym. + apply Nat.le_antisymm. + assert (INR n1 < INR (S n2)) by (rewrite S_INR; lra). apply INR_lt in H4. lia. @@ -5292,7 +5305,7 @@ Qed. apply Rbar_finite_eq. lra. -- intros. - rewrite sum_split with (m := n); try lia. + rewrite (@sum_split R_AbelianGroup) with (m := n); try lia. rewrite sum_n_m_ext_loc with (b := fun _ => zero). ++ rewrite sum_n_m_const_zero. rewrite plus_zero_l. @@ -5324,7 +5337,7 @@ Qed. apply ps_pos. + intros. unfold sum_n. - rewrite sum_split with (m := n); try lia. + rewrite (@sum_split R_AbelianGroup) with (m := n); try lia. reflexivity. Qed. @@ -5387,7 +5400,7 @@ Qed. -- replace (n0 + S (S a))%nat with (S (n0 + S a)) by lia. rewrite sum_Rbar_n_finite_sum_n. unfold sum_n. - rewrite sum_split with (m := a); try lia. + rewrite (@sum_split R_AbelianGroup) with (m := a); try lia. rewrite plus_comm. rewrite sum_n_m_ext_loc with (b := (fun _ => @zero R_AbelianGroup)). ++ rewrite sum_n_m_const_zero. @@ -5614,7 +5627,7 @@ Qed. forall n m, sum_n_m f m (m + n)%nat = sum_n (fun k => - (if (le_dec m k) then 1 else 0) * (f k)) + (if (Compare_dec.le_dec m k) then 1 else 0) * (f k)) (m + n)%nat. Proof. intros. @@ -5648,7 +5661,7 @@ Qed. forall n m, Rbar.Finite(sum_n_m f m (m + n)%nat) = sum_Rbar_n (fun k => - (if (le_dec m k) then 1 else 0) * (f k)) + (if (Compare_dec.le_dec m k) then 1 else 0) * (f k)) (S (m + n)). Proof. intros. @@ -5659,7 +5672,7 @@ Qed. Lemma sum_inv_sq_Elim : forall m, Rbar_le - (ELim_seq (sum_Rbar_n (fun j : nat => (if le_dec m j then 1 else 0) / (INR (S j))²))) + (ELim_seq (sum_Rbar_n (fun j : nat => (if Compare_dec.le_dec m j then 1 else 0) / (INR (S j))²))) (2 / (INR (S m))). Proof. intros. @@ -5825,7 +5838,7 @@ Qed. ELim_seq (sum_Rbar_n (fun n0 : nat => - (if le_dec n0 x then 1 else 0) * + (if Compare_dec.le_dec n0 x then 1 else 0) * (f n0))) = (sum_n f x). Proof. @@ -6576,7 +6589,7 @@ Qed. (ELim_seq (sum_Rbar_n (fun n0 : nat => - (if (le_dec n0 x) then 1 else 0) * + (if (Compare_dec.le_dec n0 x) then 1 else 0) * FiniteExpectation Prts (rvmult (rvsqr (X 0%nat)) (EventIndicator @@ -6595,7 +6608,7 @@ Qed. (event_inter (event_lt dom (rvabs (X 0%nat)) (INR i + 1)) (event_ge dom (rvabs (X 0%nat)) (INR i))))))) - (ELim_seq (sum_Rbar_n (fun j => (if (le_dec i j) then 1 else 0)/ (INR (S j))²))))). + (ELim_seq (sum_Rbar_n (fun j => (if (Compare_dec.le_dec i j) then 1 else 0)/ (INR (S j))²))))). ++ apply bounded_is_finite with (a := 0) (b := 2 * FiniteExpectation Prts (rvabs (X 0%nat))). --- apply ELim_seq_nneg; intros. apply sum_Rbar_n_pos; intros. @@ -6816,10 +6829,10 @@ Qed. (classic_dec (event_inter (event_lt dom (rvabs (X 0%nat)) (INR x + 1)) (event_ge dom (rvabs (X 0%nat)) (INR x))))))) - ((sum_Rbar_n (fun j : nat => (if le_dec x j then 1 else 0) / (INR (S j))²)) n))). + ((sum_Rbar_n (fun j : nat => (if Compare_dec.le_dec x j then 1 else 0) / (INR (S j))²)) n))). ** rewrite ELim_seq_scal_l; trivial. unfold Rdiv. - assert (is_finite (ELim_seq (sum_Rbar_n (fun j : nat => (if le_dec x j then 1 else 0) * / (INR (S j))²)))). + assert (is_finite (ELim_seq (sum_Rbar_n (fun j : nat => (if Compare_dec.le_dec x j then 1 else 0) * / (INR (S j))²)))). { apply bounded_is_finite with (a := 0) (b := 2 / INR (S x)). - apply ELim_seq_nneg; intros. @@ -6837,7 +6850,7 @@ Qed. now rewrite <- H12. ** intros. assert (forall j, - Rbar.Finite ((if le_dec x j then 1 else 0) * + Rbar.Finite ((if Compare_dec.le_dec x j then 1 else 0) * (FiniteExpectation Prts (rvmult (rvsqr (X 0%nat)) (EventIndicator @@ -6847,7 +6860,7 @@ Qed. (rvmult (rvsqr (X 0%nat)) (EventIndicator (classic_dec (event_inter (event_lt dom (rvabs (X 0%nat)) (INR x + 1)) (event_ge dom (rvabs (X 0%nat)) (INR x))))))) * - ((if le_dec x j then 1 else 0) / (INR (S j))²))). + ((if Compare_dec.le_dec x j then 1 else 0) / (INR (S j))²))). { intros. apply Rbar_finite_eq. @@ -6870,7 +6883,7 @@ Qed. (EventIndicator (classic_dec (event_inter (event_lt dom (rvabs (X 0%nat)) (INR x + 1)) (event_ge dom (rvabs (X 0%nat)) (INR x))))))) - ((if le_dec x x0 then 1 else 0) / (INR (S x0))²)). + ((if Compare_dec.le_dec x x0 then 1 else 0) / (INR (S x0))²)). +++ rewrite sum_n_scal_l. reflexivity. +++ intros. @@ -6897,7 +6910,7 @@ Qed. rewrite ELim_seq_ext with (v := (sum_Rbar_n (fun n0 : nat => - (if le_dec n0 x then 1 else 0) * + (if Compare_dec.le_dec n0 x then 1 else 0) * (FiniteExpectation Prts (rvmult (rvsqr (X 0%nat)) (EventIndicator diff --git a/coq/QLearn/sumtest.v b/rocq/QLearn/sumtest.v similarity index 100% rename from coq/QLearn/sumtest.v rename to rocq/QLearn/sumtest.v diff --git a/coq/QLearn/uniform_converge.v b/rocq/QLearn/uniform_converge.v similarity index 100% rename from coq/QLearn/uniform_converge.v rename to rocq/QLearn/uniform_converge.v diff --git a/coq/QLearn/vecslln.v b/rocq/QLearn/vecslln.v similarity index 99% rename from coq/QLearn/vecslln.v rename to rocq/QLearn/vecslln.v index be18ee32..45fd50fb 100644 --- a/coq/QLearn/vecslln.v +++ b/rocq/QLearn/vecslln.v @@ -1,4 +1,3 @@ -Require Import QArith.Qreals. Require Import Lra Lia Reals RealAdd RandomVariableL2 Coquelicot.Coquelicot. Require Import Morphisms FiniteType List ListAdd Permutation infprod Almost NumberIso. Require Import Sums SimpleExpectation PushNeg. @@ -348,7 +347,7 @@ Section conv_as. intros. apply almost_bounded_forall; trivial. - intros. - apply lt_dec. + apply Compare_dec.lt_dec. - intros. revert H0. apply is_lim_seq_ext. @@ -585,7 +584,7 @@ Section vec_cauchy. Lemma nth_exist_join {T} {size:nat} {P} (f:forall (i:nat) (pf:(i < size)%nat), exists (t:T), P i pf t) : exists (ts:list T), length ts = size /\ forall i pf d, P i pf (nth i ts d). Proof. - destruct (nth_exist_join_aux f _ (le_refl _)) as [??]; firstorder. + destruct (nth_exist_join_aux f _ (Nat.le_refl _)) as [??]; firstorder. Qed. Lemma Hnorm_vector0 (x : vector R 0) : hilbert.Hnorm x = 0. @@ -806,7 +805,7 @@ Section ash. (hb1 : forall n, 0 < sum_f_R0 a n)(hb2 : is_lim_seq (sum_f_R0 a) p_infty) (hx : is_lim_seq x x0): is_lim_seq (fun n => (sum_f_R0 (fun j => a j * x j) n)/(sum_f_R0 a n)) x0. Proof. - pose (A := fun (n j : nat) => if (le_dec j n) then (a j)/(sum_f_R0 a n) else 0). + pose (A := fun (n j : nat) => if (Compare_dec.le_dec j n) then (a j)/(sum_f_R0 a n) else 0). assert (Apos: forall n j, 0 <= A n j). { intros. @@ -843,7 +842,7 @@ Section ash. + now rewrite Lim_seq_const. + exists n; intros. rewrite sum_n_ext with - (b := (fun j : nat => (if le_dec j n then (((a j)*(x j)) / sum_f_R0 a n) else 0))). + (b := (fun j : nat => (if Compare_dec.le_dec j n then (((a j)*(x j)) / sum_f_R0 a n) else 0))). * rewrite <- sum_n_sum_f_clipped with (N := n); try lia. rewrite sum_n_Reals. unfold Rdiv. @@ -948,7 +947,7 @@ Section ash. cut_to H3; trivial. * eapply (is_lim_seq_ext _ _ x0 _ H3). * intros. - destruct (lt_dec 0 n). + destruct (Compare_dec.lt_dec 0 n). -- specialize (hb1 (n-1)%nat). replace (S (n-1)) with (n) in hb1 by lia. lra. @@ -1013,7 +1012,7 @@ Section ash. (hx : ex_series x): is_lim_seq (fun n => (sum_n_m (fun j => b j * x j) 1 n)/(b n)) 0. Proof. - pose (bb := fun n => if (lt_dec 0 n) then (b n) else 0). + pose (bb := fun n => if (Compare_dec.lt_dec 0 n) then (b n) else 0). generalize (@ash_6_1_3 bb x); intros. cut_to H; trivial. - apply is_lim_seq_ext with (v := fun n => sum_n_m (fun j => b j * x j) 1 (S n) / (b (S n))) in H. @@ -2611,7 +2610,7 @@ End ash. rv_unfold. rewrite Rmax_list_app by now simpl. unfold Rmax. - rewrite plus_0_l. + rewrite Nat.add_0_l. destruct (Rle_dec (Rmax_list (map (fun n : nat => F n a) (seq 0 (S k)))) (F (S k) a)); trivial. } apply IsFiniteExpectation_case. @@ -2668,7 +2667,7 @@ End ash. rv_unfold. rewrite Rmax_list_app by now simpl. unfold Rmax. - rewrite plus_0_l. + rewrite Nat.add_0_l. destruct (Rle_dec (Rmax_list (map (fun n : nat => F n a) (seq 0 (S k)))) (F (S k) a)); trivial. } apply IsFiniteExpectation_case; trivial. @@ -2695,7 +2694,7 @@ End ash. apply in_map_iff. exists j'; split; trivial. apply in_seq; lia. - - destruct (le_dec k' k). + - destruct (Compare_dec.le_dec k' k). + specialize (IHk l). eapply Rle_trans. * apply IHk. @@ -3257,7 +3256,7 @@ End ash. + apply cond_pos. + exists q. split. - * apply Rlt_Qlt. + * apply Qreals.Rlt_Qlt. unfold QArith_base.inject_Z. unfold Q2R. simpl. @@ -3270,7 +3269,7 @@ End ash. - intros [eps [epos HH]]. assert (qepspos: 0 < Q2R eps). { - apply Qlt_Rlt in epos. + apply Qreals.Qlt_Rlt in epos. now rewrite RMicromega.Q2R_0 in epos. } exists (mkposreal (Q2R eps) qepspos). @@ -3285,23 +3284,23 @@ End ash. destruct (Rlt_dec 0 (Q2R i)). - assert (QArith_base.Qlt {| QArith_base.Qnum := 0; QArith_base.Qden := 1 |} i). { - apply Rlt_Qlt. + apply Qreals.Rlt_Qlt. now rewrite RMicromega.Q2R_0. } eapply (sa_proper _ (fun omega => (forall N : nat, exists n m : nat, (n >= N)%nat /\ (m >= N)%nat /\ hilbert.Hnorm (minus (X n omega) (X m omega)) >= Q2R i))). - + firstorder. + + red; simpl; intuition. + apply sa_pre_countable_inter; intros N. now apply (vec_sa_sigma_not_cauchy X (mkposreal _ r)). - eapply sa_proper; try apply sa_none. assert (~ QArith_base.Qlt {| QArith_base.Qnum := 0; QArith_base.Qden := 1 |} i). { intros qlt. - apply Qlt_Rlt in qlt. + apply Qreals.Qlt_Rlt in qlt. now rewrite RMicromega.Q2R_0 in qlt. } - firstorder. + red; intuition. Qed. Lemma vec_sa_sigma_cauchy_descending {size:nat} (X : nat -> Ts -> vector R size )(eps : posreal) @@ -3660,7 +3659,7 @@ End ash. Qed. Lemma vec_sum_n_m_shift {size:nat} (α : nat -> vector R size) (k n0 : nat) : - sum_n_m α k (n0 + k)%nat = sum_n (fun n1 : nat => α (n1 + k)%nat) n0. + @sum_n_m Rvector_AbelianGroup α k (n0 + k)%nat = @sum_n Rvector_AbelianGroup (fun n1 : nat => α (n1 + k)%nat) n0. Proof. unfold sum_n. induction n0. @@ -3675,9 +3674,10 @@ End ash. Qed. Lemma vec_sum_shift_diff {size:nat} (X : nat -> vector R size) (m a : nat) : - minus (sum_n X (a + S m)) (sum_n X m) = - sum_n (fun n0 : nat => X (n0 + S m)%nat) a. + minus (@sum_n Rvector_AbelianGroup X (a + S m)) (@sum_n Rvector_AbelianGroup X m) = + @sum_n Rvector_AbelianGroup (fun n0 : nat => X (n0 + S m)%nat) a. Proof. + simpl. rewrite <- vec_sum_n_m_shift. unfold sum_n. rewrite (@sum_split _ _ _ _ m); try lia. @@ -3806,7 +3806,7 @@ End ash. rewrite minus_eq_zero in r. rewrite (@hilbert.norm_zero) in r. generalize (is_pos_div_2 eps); intros; lra. - + assert (n > N)%nat by (destruct H; try lia;firstorder). + + assert (n > N)%nat by (destruct H; try lia;intuition). exists (n - (S N))%nat. unfold vecrvminus. now replace (n - S N + S N)%nat with (n) by lia. diff --git a/coq/lib_utils/LibUtils.v b/rocq/lib_utils/LibUtils.v similarity index 100% rename from coq/lib_utils/LibUtils.v rename to rocq/lib_utils/LibUtils.v diff --git a/coq/lib_utils/LibUtilsAssoc.v b/rocq/lib_utils/LibUtilsAssoc.v similarity index 100% rename from coq/lib_utils/LibUtilsAssoc.v rename to rocq/lib_utils/LibUtilsAssoc.v diff --git a/coq/lib_utils/LibUtilsBag.v b/rocq/lib_utils/LibUtilsBag.v similarity index 98% rename from coq/lib_utils/LibUtilsBag.v rename to rocq/lib_utils/LibUtilsBag.v index 5ee53180..507f8c42 100644 --- a/coq/lib_utils/LibUtilsBag.v +++ b/rocq/lib_utils/LibUtilsBag.v @@ -17,8 +17,6 @@ (** This module provides support for bags (or multisets). *) Require Import Arith. -Require Import Min. -Require Import Max. Require Import ZArith. Require Import Lia. Require Import Permutation. @@ -843,8 +841,8 @@ Section Bag. generalize (mult x0 a); intro; auto with arith. rewrite H2. rewrite IHl; rewrite H2. - generalize (mult l a); intro; rewrite <- succ_min_distr; auto with arith. - rewrite H0; rewrite min_0_l; simpl. + generalize (mult l a); intro; rewrite <- Nat.succ_min_distr; auto with arith. + rewrite H0; rewrite Nat.min_0_l; simpl. elim (equiv_dec a a); intro; try congruence. rewrite IHl; rewrite H0; auto with arith. rewrite IHl. @@ -878,10 +876,10 @@ Section Bag. + rewrite H in *; clear H. case (equiv_dec x x); try congruence. intros. - apply gt_Sn_O. + auto with arith. + case (equiv_dec x a). * intros. - apply gt_Sn_O. + auto with arith. * intros. apply IHl. assumption. @@ -896,7 +894,7 @@ Section Bag. apply mult_pos_equiv_in. apply mult_pos_equiv_in in Hx. rewrite bdistinct_mult. - apply min_case; auto. + apply Nat.min_case; auto. Qed. Lemma bdistinct_nil {l:list A} : @@ -1033,7 +1031,7 @@ Section Bag. Proof. intros; revert s. induction t; simpl. - intro. rewrite <- minus_n_O; reflexivity. + intro. rewrite Nat.sub_0_r; reflexivity. intros; rewrite IHt. destruct (equiv_dec x a). rewrite e in *; clear e. @@ -1103,7 +1101,7 @@ Section Bag. forall (n n0:nat), n0 - (n0 - n) = min n0 n. Proof. induction n. - intros; rewrite min_0_r; rewrite <- minus_n_O; rewrite minus_diag; reflexivity. + intros; rewrite Nat.min_0_r; rewrite Nat.sub_0_r; rewrite Nat.sub_diag; reflexivity. intros. generalize eq_nat_dec; intro. elim (H n n0); intro; clear H. @@ -1120,7 +1118,7 @@ Section Bag. forall (n n0:nat), n0 + (n - n0) = max n0 n. Proof. induction n; intros. - rewrite max_0_r; auto with arith. + rewrite Nat.max_0_r; auto with arith. generalize eq_nat_dec; intro. elim (H n n0); intro; clear H. rewrite a in *; clear a. @@ -1171,7 +1169,7 @@ Section Bag. intro. rewrite (bmin_mult s t a). rewrite (bmin_mult t s a). - rewrite min_comm. + rewrite Nat.min_comm. reflexivity. Qed. @@ -1183,7 +1181,7 @@ Section Bag. intro. rewrite (bmax_mult s t a). rewrite (bmax_mult t s a). - rewrite max_comm. + rewrite Nat.max_comm. reflexivity. Qed. @@ -1255,7 +1253,7 @@ Section Bag. revert H IHl1 H0; generalize (mult l1 a); generalize (mult l2 a); intros. assert (n = 0). + apply min_zero with (n2 := 0); assumption. - + clear H; rewrite H1 in *; rewrite plus_0_r in *; assumption. + + clear H; rewrite H1 in *; rewrite Nat.add_0_r in *; assumption. - apply IHl1; assumption. Qed. @@ -1281,7 +1279,7 @@ Section Bag. rewrite H; assumption. - assert (n = 0). apply min_zero with (n2 := 0); assumption. - rewrite H; rewrite plus_0_r; assumption. + rewrite H; rewrite Nat.add_0_r; assumption. - assert (1 <= n). apply min_one_yields_one; assumption. assert (1 <= n0). @@ -1317,17 +1315,17 @@ Section Bag. (mult l1 x = 0) \/ (mult l2 x = 0) -> mult (l1 min-b l2) x = 0. Proof. induction l2; simpl; intros; rewrite bmin_mult; simpl. - rewrite min_0_r; reflexivity. + rewrite Nat.min_0_r; reflexivity. revert H; elim (equiv_dec x a); intros. rewrite a0 in *; clear a0. elim H; intro; clear H. rewrite H0 in *. generalize (mult l2 a); auto. rewrite H0 in *. - generalize (mult l1 a); intro; rewrite min_0_r; reflexivity. + generalize (mult l1 a); intro; rewrite Nat.min_0_r; reflexivity. elim H; intro; clear H; rewrite H0. generalize (mult l1 a); auto. - generalize (mult l1 a); intro; rewrite min_0_r; reflexivity. + generalize (mult l1 a); intro; rewrite Nat.min_0_r; reflexivity. Qed. Lemma bdistinct_over_bmin: @@ -1350,7 +1348,7 @@ Section Bag. assert (n = 0); try (apply (min_zero n 0); assumption). rewrite H0; auto with arith. assert (n0 = 0); try (apply (min_zero n0 0); assumption). - rewrite H0; rewrite min_0_r; auto with arith. + rewrite H0; rewrite Nat.min_0_r; auto with arith. simpl; generalize (compare_either n n0); intro. elim H0; intro; clear H0. assert (min n n0 = n). try (rewrite min_l; [reflexivity|assumption]). diff --git a/coq/lib_utils/LibUtilsBindings.v b/rocq/lib_utils/LibUtilsBindings.v similarity index 99% rename from coq/lib_utils/LibUtilsBindings.v rename to rocq/lib_utils/LibUtilsBindings.v index 22b7139c..28b739d4 100644 --- a/coq/lib_utils/LibUtilsBindings.v +++ b/rocq/lib_utils/LibUtilsBindings.v @@ -43,20 +43,21 @@ Require Import LibUtilsStringAdd. Section Bindings. Class ODT {K:Type} - := mkODT { ODT_eqdec:>EqDec K eq; + := mkODT { ODT_eqdec::EqDec K eq; ODT_lt:K -> K -> Prop; - ODT_lt_strorder:>StrictOrder ODT_lt; + ODT_lt_strorder::StrictOrder ODT_lt; ODT_lt_dec: forall (a b:K), {ODT_lt a b} + {~ODT_lt a b}; ODT_compare:K -> K -> comparison; ODT_compare_spec: forall x y : K, CompareSpec (eq x y) (ODT_lt x y) (ODT_lt y x) (ODT_compare x y) }. + Generalizable Variables K. Context `{odt:@ODT K}. Lemma ODT_lt_irr (k:K) : ~(ODT_lt k k). - Proof. + Proof. apply irreflexivity. Qed. diff --git a/coq/lib_utils/LibUtilsBindingsNat.v b/rocq/lib_utils/LibUtilsBindingsNat.v similarity index 92% rename from coq/lib_utils/LibUtilsBindingsNat.v rename to rocq/lib_utils/LibUtilsBindingsNat.v index bd820896..0de72e7d 100644 --- a/coq/lib_utils/LibUtilsBindingsNat.v +++ b/rocq/lib_utils/LibUtilsBindingsNat.v @@ -18,13 +18,12 @@ natural number. *) Require Import Arith. -Require Import NPeano. Require Import LibUtilsBindings. Section BindingsNat. Global Program Instance ODT_nat : (@ODT nat) - := mkODT _ _ lt Nat.lt_strorder lt_dec Nat.compare _. + := mkODT _ _ lt Nat.lt_strorder Compare_dec.lt_dec Nat.compare _. Next Obligation. simpl. apply Nat.compare_spec. diff --git a/coq/lib_utils/LibUtilsClosure.v b/rocq/lib_utils/LibUtilsClosure.v similarity index 100% rename from coq/lib_utils/LibUtilsClosure.v rename to rocq/lib_utils/LibUtilsClosure.v diff --git a/coq/lib_utils/LibUtilsCompat.v b/rocq/lib_utils/LibUtilsCompat.v similarity index 100% rename from coq/lib_utils/LibUtilsCompat.v rename to rocq/lib_utils/LibUtilsCompat.v diff --git a/coq/lib_utils/LibUtilsCoqLibAdd.v b/rocq/lib_utils/LibUtilsCoqLibAdd.v similarity index 98% rename from coq/lib_utils/LibUtilsCoqLibAdd.v rename to rocq/lib_utils/LibUtilsCoqLibAdd.v index ed65514a..d887466e 100644 --- a/coq/lib_utils/LibUtilsCoqLibAdd.v +++ b/rocq/lib_utils/LibUtilsCoqLibAdd.v @@ -30,7 +30,6 @@ Require Import EquivDec. Require Import Equivalence. Require Import Peano_dec. Require Import ZArith. -Require Import Zdigits. Require Import Znat. Require Import Recdef. Require Import Compare_dec. @@ -186,7 +185,7 @@ Section CoqLibAdd. Lemma nin_app_or (x:A) a b : (~ In x (a ++ b)) <-> (~ In x a /\ ~ In x b). Proof. - intuition. apply in_app_or in H0; intuition. + intuition (auto with *). apply in_app_or in H0; intuition. Qed. Lemma in_in_app_false {l l1 l2} : @@ -463,7 +462,7 @@ Section CoqLibAdd. (forallb f l1 = true /\ forallb f l2 = true). Proof. repeat rewrite forallb_forall. - intuition; rewrite in_app_iff in *; intuition. + intuition (auto with *); rewrite in_app_iff in *; intuition (auto with *). Qed. Lemma forallb_map {A B} f (mf:A->B) m : forallb f (map mf m) = forallb ((fun x => f (mf x))) m. @@ -575,7 +574,7 @@ Section CoqLibAdd. reflexivity. assert (exists (n3:nat), min (S n1) (S n2) = (S n3)). exists (min n1 n2). - rewrite Min.succ_min_distr; reflexivity. + rewrite Nat.succ_min_distr; reflexivity. elim H0; intros. congruence. Qed. @@ -622,8 +621,8 @@ Section CoqLibAdd. revert x0; induction l; simpl; intros; try lia. rewrite (IHl (n0 * f a + x0)); simpl. rewrite (fold_left_arith_dist1 (f a + 0)). - rewrite mult_plus_distr_l. - rewrite mult_plus_distr_l. + rewrite Nat.mul_add_distr_l. + rewrite Nat.mul_add_distr_l. lia. Qed. @@ -633,7 +632,7 @@ Section CoqLibAdd. Proof. generalize 0. revert x0; induction l; simpl; intros; try lia. - assert (f a + n = n + f a) by apply plus_comm. + assert (f a + n = n + f a) by apply Nat.add_comm. rewrite H; clear H. rewrite (IHl x0 (n + f a)); reflexivity. Qed. @@ -803,7 +802,7 @@ Section CoqLibAdd. Context {A:Type}. Function iter_cost (optim: A -> A) (cost: A -> nat) (p: A) { measure cost p } := let p' := optim p in - if lt_dec (cost p') (cost p) + if Compare_dec.lt_dec (cost p') (cost p) then iter_cost optim cost p' else p. Proof. @@ -1015,3 +1014,4 @@ Ltac string_eqdec_to_equiv := Ltac string_dec_to_equiv := replace string_dec with (equiv_dec (EqDec:=string_dec)) in * by trivial. + diff --git a/coq/lib_utils/LibUtilsDigits.v b/rocq/lib_utils/LibUtilsDigits.v similarity index 92% rename from coq/lib_utils/LibUtilsDigits.v rename to rocq/lib_utils/LibUtilsDigits.v index 49db3ea3..b52cd973 100644 --- a/coq/lib_utils/LibUtilsDigits.v +++ b/rocq/lib_utils/LibUtilsDigits.v @@ -61,13 +61,13 @@ Section prelude. generalize (refl_equal n). pattern n at 2 4 6 10, q; case q; [intro | intros m l e]. rewrite <- eq_rect_eq_nat; trivial. - contradiction (le_Sn_n m); rewrite <- e; assumption. + contradiction (Nat.nle_succ_diag_l m); rewrite <- e; assumption. replace (le_S n m p) with (eq_rect _ (fun n0 => n <= n0) (le_S n m p) _ (refl_equal (S m))). 2:reflexivity. generalize (refl_equal (S m)). pattern (S m) at 1 3 4 6, q; case q; [intro Heq | intros m0 l HeqS]. - contradiction (le_Sn_n m); rewrite Heq; assumption. + contradiction (Nat.nle_succ_diag_l m); rewrite Heq; assumption. injection HeqS; intro Heq; generalize l HeqS. rewrite <- Heq; intros; rewrite <- eq_rect_eq_nat. rewrite (IHp l0); reflexivity. @@ -157,7 +157,7 @@ Section Digits. split. - unfold digits_to_nat in e1. rewrite e1. - rewrite mult_comm. + rewrite Nat.mul_comm. rewrite <- Nat.div_mod; trivial. lia. - intros ? HH. destruct (rev x); simpl in * . @@ -167,7 +167,7 @@ Section Digits. simpl in *. rewrite <- Nat.div_exact by lia. rewrite <- e1. - rewrite mult_comm. + rewrite Nat.mul_comm. simpl. lia. + auto. @@ -192,8 +192,8 @@ Section Digits. induction l; simpl; trivial. intros. apply IHl. - apply plus_le_compat_r. - apply mult_le_compat_r. + apply Nat.add_le_mono_r. + apply Nat.mul_le_mono_r. trivial. Qed. @@ -209,9 +209,9 @@ Section Digits. transitivity (acc*base). * transitivity (acc * 1). { lia. } - apply mult_le_compat_l. + apply Nat.mul_le_mono_l. lia. - * apply le_plus_l. + * apply Nat.le_add_r. Qed. Lemma digits_to_nat_aux_bound l c: @@ -221,34 +221,28 @@ Section Digits. induction l; simpl. - split. + lia. - + destruct (mult_O_le c (base*1)). - * lia. - * rewrite mult_comm. - lia. + + lia. - intros. destruct (IHl (c * base + proj1_sig a)) as [le1 le2]. clear IHl. split. + rewrite <- le1. - rewrite mult_plus_distr_r. - rewrite mult_assoc. - apply le_plus_l. - + eapply lt_le_trans; [apply le2 | ]. - repeat rewrite mult_plus_distr_r. - repeat rewrite mult_assoc. + lia. + + eapply Nat.lt_le_trans; [apply le2 | ]. + repeat rewrite Nat.mul_add_distr_r. + repeat rewrite Nat.mul_assoc. repeat rewrite Nat.mul_1_l. - rewrite plus_assoc_reverse. - apply plus_le_compat_l. + rewrite <- Nat.add_assoc. + apply Nat.add_le_mono_l. replace (proj1_sig a * base ^ Datatypes.length l + base ^ Datatypes.length l) with ((proj1_sig a +1) * base ^ Datatypes.length l). - * apply mult_le_compat_r. + * apply Nat.mul_le_mono_r. destruct a; simpl. - rewrite plus_comm; simpl. - apply lt_le_S. - trivial. - * rewrite mult_plus_distr_r, Nat.mul_1_l; trivial. + rewrite Nat.add_comm; simpl. + auto with arith. + * rewrite Nat.mul_add_distr_r, Nat.mul_1_l; trivial. Qed. Lemma digits_to_nat_aux_acc_inj_helper1 a b c n1 n2 : @@ -260,71 +254,64 @@ Section Digits. Proof. intros ? ? ? lt1 ltn. assert (le12:c * base * base ^ n2 + 0 <= c * base * base ^ n2 + a * base ^ n2). - { apply plus_le_compat_l. + { apply Nat.add_le_mono_l. apply Peano.le_0_n. } - rewrite plus_0_r in le12. - rewrite mult_plus_distr_r in lt1. - eapply le_lt_trans in lt1; try eapply le12. + rewrite Nat.add_0_r in le12. + rewrite Nat.mul_add_distr_r in lt1. + eapply Nat.le_lt_trans in lt1; try eapply le12. assert (le13:(c * base + b + 1) * (base ^ n1) <= (c * base + base) * (base ^ n1 )). { - apply mult_le_compat_r. - rewrite plus_assoc_reverse. - apply plus_le_compat_l. + apply Nat.mul_le_mono_r. + rewrite <- Nat.add_assoc. + apply Nat.add_le_mono_l. lia. } - eapply lt_le_trans in le13; try eapply lt1. - rewrite (le_plus_minus n1 n2) in le13 by lia. + eapply Nat.lt_le_trans in le13; try eapply lt1. + rewrite <- (Nat.sub_add n1 n2) in le13 by lia. rewrite Nat.pow_add_r in le13. - rewrite mult_assoc in le13. + rewrite Nat.mul_assoc in le13. assert (le14:c*base+base <= c*base*base). { replace (c*base+base) with ((c+1)*base). - - apply mult_le_compat_r. - rewrite mult_comm. + - apply Nat.mul_le_mono. + rewrite Nat.mul_comm. destruct base. + lia. + simpl. - apply plus_le_compat_l. + apply Nat.add_le_mono_l. destruct n. lia. destruct c. lia. - apply lt_le_S. - replace 0 with (S n *0) by auto. - apply mult_lt_compat_l; lia. - - rewrite mult_plus_distr_r. - rewrite mult_1_l. + apply Nat.lt_succ_r. + auto with arith. + + reflexivity. + - rewrite Nat.mul_add_distr_r. + rewrite Nat.mul_1_l. trivial. } assert (le15:(c * base + base) * base ^ n1 <= (c * base * base) * base ^ n1). { - apply mult_le_compat_r. + apply Nat.mul_le_mono_r. auto. } - eapply lt_le_trans in le15; try eapply le13. + eapply Nat.lt_le_trans in le15; try eapply le13. assert (le16:c * base * base ^ n1 * base <= c * base * base ^ n1 * base ^ (n2 - n1)). { - apply mult_le_compat_l. + apply Nat.mul_le_mono_l. generalize (Nat.sub_gt _ _ ltn). destruct (n2-n1). - congruence. - simpl; intros _ . replace base with (base*base^0) at 1. - + apply mult_le_compat_l. + + apply Nat.mul_le_mono_l. apply Nat.pow_le_mono_r; lia. + simpl. - rewrite mult_1_r. + rewrite Nat.mul_1_r. trivial. } - eapply le_lt_trans in le16; try eapply le15. - replace (c * base * base ^ n1 * base) with - (c * base * base * base ^ n1) in le16. - - intuition. - - repeat rewrite mult_assoc_reverse. - f_equal. f_equal. - rewrite mult_comm. - trivial. + lia. Qed. Lemma digits_to_nat_aux_acc_inj_helper12 a b c n1 n2 : @@ -339,15 +326,14 @@ Section Digits. ; [ | eapply (digits_to_nat_aux_acc_inj_helper1 a b c n1 n2); eauto]. red in e; subst. simpl in *. - rewrite (le_plus_minus n1 n2) in lt1 by lia. + rewrite <- (Nat.sub_add n1 n2) in lt1 by lia. rewrite Nat.pow_add_r in lt1. - rewrite (mult_comm (base ^ n1)) in lt1. - rewrite mult_assoc in lt1. + rewrite Nat.mul_assoc in lt1. assert (le2:base*base^n1 <= a*base^(n2 - n1) * base ^ n1). { - apply mult_le_compat_r. + apply Nat.mul_le_mono_r. replace base with (1*base) at 1 by lia. - apply mult_le_compat. + apply Nat.mul_le_mono. - replace 1 with (1*1) by lia. simpl. lia. - simpl. @@ -355,13 +341,13 @@ Section Digits. + apply Nat.pow_le_mono_r; lia. + apply Nat.pow_1_r. } - eapply le_lt_trans in lt1; try eapply le2; clear le2. + eapply Nat.le_lt_trans in lt1; try eapply le2; clear le2. assert (le3:(b + 1) * base ^ n1 <= base * base^n1). { - apply mult_le_compat_r. + apply Nat.mul_le_mono_r. lia. } - eapply le_lt_trans in lt1; try eapply le3; clear le3. + eapply Nat.le_lt_trans in lt1; try eapply le3; clear le3. lia. Qed. @@ -370,11 +356,10 @@ Section Digits. ~ b < a. Proof. intros lt1 l2. - apply lt_not_le in lt1. + apply Nat.le_ngt in lt1. apply lt1. - apply mult_le_compat_r. - rewrite plus_assoc_reverse. - apply plus_le_compat_l. + apply Nat.lt_succ_r. + apply Nat.mul_le_mono_r. lia. Qed. @@ -386,21 +371,21 @@ Section Digits. ~ n2 < n1. Proof. intros ? ? ? lt1 l2. - apply lt_not_le in lt1. + apply Nat.le_ngt in lt1. apply lt1. - rewrite (le_plus_minus n2 n1) by lia. + apply Nat.lt_succ_r. + rewrite <- (Nat.sub_add n2 n1) by lia. rewrite Nat.pow_add_r. - rewrite (mult_comm a). - rewrite (mult_comm (b+1)). - rewrite <- mult_assoc. - apply mult_le_compat_l. + rewrite Nat.mul_assoc. + apply Nat.mul_le_mono_r. transitivity base; try lia. transitivity (base^1*a). - rewrite Nat.pow_1_r. transitivity (base * 1); try lia. - apply mult_le_compat_l. + apply Nat.mul_le_mono_l. lia. - - apply mult_le_compat_r. + - rewrite Nat.mul_comm. + apply Nat.mul_le_mono_l. apply Nat.pow_le_mono_r; lia. Qed. @@ -414,8 +399,8 @@ Section Digits. destruct (digits_to_nat_aux_bound l1 (c*base+a)) as [lb1 ub1]. destruct (digits_to_nat_aux_bound l2 (c*base+b)) as [lb2 ub2]. rewrite eqq1 in lb1,ub1. - eapply le_lt_trans in lb1; [ | eapply ub2]. - eapply le_lt_trans in lb2; [ | eapply ub1]. + eapply Nat.le_lt_trans in lb1; [ | eapply ub2]. + eapply Nat.le_lt_trans in lb2; [ | eapply ub1]. clear eqq1 ub1 ub2. revert lb1 lb2. generalize (Datatypes.length l1). @@ -448,8 +433,8 @@ Section Digits. destruct (digits_to_nat_aux_bound l1 (c*base+a)) as [lb1 ub1]. destruct (digits_to_nat_aux_bound l2 (c*base+b)) as [lb2 ub2]. rewrite eqq1 in lb1,ub1. - eapply le_lt_trans in lb1; [ | eapply ub2]. - eapply le_lt_trans in lb2; [ | eapply ub1]. + eapply Nat.le_lt_trans in lb1; [ | eapply ub2]. + eapply Nat.le_lt_trans in lb2; [ | eapply ub1]. clear eqq1 ub1 ub2. revert lb1 lb2. generalize (Datatypes.length l1). @@ -485,15 +470,13 @@ Section Digits. assert (le1:n * base <= n*1) by lia. assert (le2:n * base <= n*1) by lia. destruct n; [congruence|]. - apply mult_S_le_reg_l in le2. - lia. + apply Nat.mul_le_mono_pos_l in le2; lia. - generalize (digits_to_nat_aux_le l1 (n * base + proj1_sig a)); intros eqq. rewrite H0 in eqq. assert (le1:n * base <= n*1) by lia. assert (le2:n * base <= n*1) by lia. destruct n; [congruence|]. - apply mult_S_le_reg_l in le2. - lia. + apply Nat.mul_le_mono_pos_l in le2; lia. - assert (lt0:0 - match lt_dec n base with + match Compare_dec.lt_dec n base with | left pf => Some (exist _ n pf) | right _ => None end @@ -868,7 +848,7 @@ Section Digits. Proof. unfold char_to_digit. simpl. - destruct (lt_dec 0 base). + destruct (Compare_dec.lt_dec 0 base). - f_equal. apply digit_ext. simpl; trivial. @@ -934,7 +914,7 @@ Section Digits. destruct (ascii_dec a "0"%char). - subst. unfold char_to_digit in eqq; simpl in eqq. - destruct (lt_dec 0 base); [ | lia]. + destruct (Compare_dec.lt_dec 0 base); [ | lia]. inversion eqq; clear eqq; subst. case_eq (string_to_digits s) ; [intros ? eqq2 | intros eqq2] @@ -1288,21 +1268,21 @@ End Digits. Section Bases. Definition lt_decider (a b:nat) : - match lt_dec a b with + match Compare_dec.lt_dec a b with | left pf => lt a b | right _ => True end. Proof. - destruct (lt_dec); trivial. + destruct (Compare_dec.lt_dec); trivial. Defined. Definition le_decider (a b:nat) : - match le_dec a b with + match Compare_dec.le_dec a b with | left pf => le a b | right _ => True end. Proof. - destruct (le_dec); trivial. + destruct (Compare_dec.le_dec); trivial. Defined. (** ** Base 2 *) diff --git a/coq/lib_utils/LibUtilsFresh.v b/rocq/lib_utils/LibUtilsFresh.v similarity index 99% rename from coq/lib_utils/LibUtilsFresh.v rename to rocq/lib_utils/LibUtilsFresh.v index 754186d2..428a8dad 100644 --- a/coq/lib_utils/LibUtilsFresh.v +++ b/rocq/lib_utils/LibUtilsFresh.v @@ -20,7 +20,7 @@ as strings). *) Require Import String. Require Import List. Require Import Permutation. -Require Import Arith Min. +Require Import Arith. Require Import EquivDec. Require Import Morphisms. Require Import Lia. @@ -121,7 +121,7 @@ Section Fresh. Defined. Definition find_fresh_inj_f {A:Type} {dec:EqDec A eq} f (inj:forall x y, f x = f y -> x = y) (dom:list A) : A - := proj1_sig (find_bounded_S_nin_finds f dom (S (length dom)) (gt_Sn_n _) inj). + := proj1_sig (find_bounded_S_nin_finds f dom (S (length dom)) (Nat.lt_succ_diag_r _) inj). Lemma find_fresh_inj_f_fresh {A:Type} {dec:EqDec A eq} f (inj:forall x y, f x = f y -> x = y) (dom:list A) : ~ In (find_fresh_inj_f f inj dom) dom. diff --git a/coq/lib_utils/LibUtilsGroupByDomain.v b/rocq/lib_utils/LibUtilsGroupByDomain.v similarity index 100% rename from coq/lib_utils/LibUtilsGroupByDomain.v rename to rocq/lib_utils/LibUtilsGroupByDomain.v diff --git a/coq/lib_utils/LibUtilsInterleaved.v b/rocq/lib_utils/LibUtilsInterleaved.v similarity index 100% rename from coq/lib_utils/LibUtilsInterleaved.v rename to rocq/lib_utils/LibUtilsInterleaved.v diff --git a/coq/lib_utils/LibUtilsLattice.v b/rocq/lib_utils/LibUtilsLattice.v similarity index 98% rename from coq/lib_utils/LibUtilsLattice.v rename to rocq/lib_utils/LibUtilsLattice.v index 747300b4..902a9440 100644 --- a/coq/lib_utils/LibUtilsLattice.v +++ b/rocq/lib_utils/LibUtilsLattice.v @@ -54,8 +54,8 @@ Section Lattice. meet : A -> A -> A; join : A -> A -> A; - meet_morphism :> Proper (eqA ==> eqA ==> eqA) meet ; - join_morphism :> Proper (eqA ==> eqA ==> eqA) join ; + meet_morphism :: Proper (eqA ==> eqA ==> eqA) meet ; + join_morphism :: Proper (eqA ==> eqA ==> eqA) join ; meet_commutative : commutative eqA meet; meet_associative : associative eqA meet; diff --git a/coq/lib_utils/LibUtilsLift.v b/rocq/lib_utils/LibUtilsLift.v similarity index 100% rename from coq/lib_utils/LibUtilsLift.v rename to rocq/lib_utils/LibUtilsLift.v diff --git a/coq/lib_utils/LibUtilsLiftIterators.v b/rocq/lib_utils/LibUtilsLiftIterators.v similarity index 99% rename from coq/lib_utils/LibUtilsLiftIterators.v rename to rocq/lib_utils/LibUtilsLiftIterators.v index 4446062d..d5043bce 100644 --- a/coq/lib_utils/LibUtilsLiftIterators.v +++ b/rocq/lib_utils/LibUtilsLiftIterators.v @@ -17,8 +17,6 @@ (** This module provides support for monadic iterators over bags. *) Require Import Arith. -Require Import Min. -Require Import Max. Require Import Lia. Require Import Permutation. Require Import Equivalence. diff --git a/coq/lib_utils/LibUtilsListAdd.v b/rocq/lib_utils/LibUtilsListAdd.v similarity index 100% rename from coq/lib_utils/LibUtilsListAdd.v rename to rocq/lib_utils/LibUtilsListAdd.v diff --git a/coq/lib_utils/LibUtilsResult.v b/rocq/lib_utils/LibUtilsResult.v similarity index 100% rename from coq/lib_utils/LibUtilsResult.v rename to rocq/lib_utils/LibUtilsResult.v diff --git a/coq/lib_utils/LibUtilsSortingAdd.v b/rocq/lib_utils/LibUtilsSortingAdd.v similarity index 100% rename from coq/lib_utils/LibUtilsSortingAdd.v rename to rocq/lib_utils/LibUtilsSortingAdd.v diff --git a/coq/lib_utils/LibUtilsStringAdd.v b/rocq/lib_utils/LibUtilsStringAdd.v similarity index 99% rename from coq/lib_utils/LibUtilsStringAdd.v rename to rocq/lib_utils/LibUtilsStringAdd.v index 67d43375..3ea25391 100644 --- a/coq/lib_utils/LibUtilsStringAdd.v +++ b/rocq/lib_utils/LibUtilsStringAdd.v @@ -22,7 +22,6 @@ Require Import Ascii. Require Import List. Require Import String. Require Import Arith. -Require Import Min. Require Import Equivalence. Require Import EquivDec. Require Import Compare_dec. @@ -379,7 +378,7 @@ Section Prefix. rewrite <- (substring_all l) at 3. apply substring_le_prefix. replace (String.length l - 0) with (String.length l) by lia. - apply le_min_r. + apply Nat.le_min_r. Qed. Lemma in_of_append pre y l : diff --git a/coq/lib_utils/LibUtilsSublist.v b/rocq/lib_utils/LibUtilsSublist.v similarity index 100% rename from coq/lib_utils/LibUtilsSublist.v rename to rocq/lib_utils/LibUtilsSublist.v diff --git a/coq/lib_utils/README.md b/rocq/lib_utils/README.md similarity index 100% rename from coq/lib_utils/README.md rename to rocq/lib_utils/README.md diff --git a/coq/utils/Assoc.v b/rocq/utils/Assoc.v similarity index 100% rename from coq/utils/Assoc.v rename to rocq/utils/Assoc.v diff --git a/coq/utils/BasicUtils.v b/rocq/utils/BasicUtils.v similarity index 100% rename from coq/utils/BasicUtils.v rename to rocq/utils/BasicUtils.v diff --git a/coq/utils/ClassicUtils.v b/rocq/utils/ClassicUtils.v similarity index 100% rename from coq/utils/ClassicUtils.v rename to rocq/utils/ClassicUtils.v diff --git a/coq/utils/CoquelicotAdd.v b/rocq/utils/CoquelicotAdd.v similarity index 100% rename from coq/utils/CoquelicotAdd.v rename to rocq/utils/CoquelicotAdd.v diff --git a/coq/utils/DVector.v b/rocq/utils/DVector.v similarity index 98% rename from coq/utils/DVector.v rename to rocq/utils/DVector.v index 1c98f433..628d0edf 100644 --- a/coq/utils/DVector.v +++ b/rocq/utils/DVector.v @@ -211,7 +211,7 @@ Proof. apply In_nth_error in inn. destruct inn as [i eqq]. destruct x; simpl in *. - destruct (lt_dec i (length x)). + destruct (Compare_dec.lt_dec i (length x)). - subst. exists i, l. unfold vector_nth, proj1_sig. @@ -360,7 +360,7 @@ Lemma vector_list_create_shiftS (start len:nat) (f:(forall m, S start <= m -> m < S start + len -> T)%nat) : vector_list_create (S start) len f = - vector_list_create start len (fun x pf1 pf2 => f (S x)%nat (le_n_S _ _ pf1) (lt_n_S _ _ pf2)). + vector_list_create start len (fun x pf1 pf2 => f (S x)%nat (le_n_S _ _ pf1) (proj1 (Nat.succ_lt_mono _ _) pf2)). Proof. revert start f. induction len; simpl; trivial; intros. @@ -376,7 +376,7 @@ Lemma vector_list_create_shift0 (start len:nat) (f:(forall m, start <= m -> m < start + len -> T)%nat) : vector_list_create start len f = - vector_list_create 0 len (fun x _ pf2 => f (start+x)%nat (le_plus_l start _) (plus_lt_compat_l _ _ start pf2)). + vector_list_create 0 len (fun x _ pf2 => f (start+x)%nat (Nat.le_add_r start _) (proj1 (Nat.add_lt_mono_l _ _ start) pf2)). Proof. induction start; simpl. - apply vector_list_create_ext; intros. @@ -448,7 +448,7 @@ Lemma vector_nth_create (i : nat) (pf2: i < len) (f:(forall m, start <= m -> m < start + len -> T)%nat) : - vector_nth i pf2 (vector_create start len f) = f (start + i) (PeanoNat.Nat.le_add_r start i) (Plus.plus_lt_compat_l _ _ start pf2). + vector_nth i pf2 (vector_create start len f) = f (start + i) (PeanoNat.Nat.le_add_r start i) (proj1 (Nat.add_lt_mono_l _ _ start) pf2). Proof. unfold vector_nth, proj1_sig. repeat match_destr. @@ -507,7 +507,7 @@ Program Definition vector_zip {A B:Type} Next Obligation. rewrite combine_length. repeat rewrite vector_length. - now rewrite Min.min_idempotent. + now rewrite Nat.min_idempotent. Qed. (* move this *) @@ -907,7 +907,7 @@ Qed. Lemma vector_nthS {A} a i (l:list A) pf1 pf2 : (vector_nth (S i) pf1 (exist (fun l0 : list A => length l0 = S (length l)) (a :: l) pf2)) - = vector_nth i (lt_S_n _ _ pf1) (exist (fun l0 : list A => length l0 = length l) (l) (eq_add_S _ _ pf2)). + = vector_nth i (proj2 (Nat.succ_lt_mono _ _) pf1) (exist (fun l0 : list A => length l0 = length l) (l) (eq_add_S _ _ pf2)). Proof. unfold vector_nth, proj1_sig. repeat match_destr. @@ -1186,7 +1186,7 @@ Section ivector. | 0%nat => fun pf _ => False_rect _ (Nat.nlt_0_r _ pf) | S n' => match idx with | 0%nat => fun pf '(hd,tl) => hd - | S m' => fun pf '(hd,tl) => ivector_nth m' (lt_S_n _ _ pf) tl + | S m' => fun pf '(hd,tl) => ivector_nth m' (proj2 (Nat.succ_lt_mono _ _) pf) tl end end. @@ -1456,7 +1456,7 @@ Section ivector. + specialize (H 0 ((Nat.lt_0_succ _))). apply H. + apply IHn; intros. - specialize (H (S i1) ( (lt_n_S _ _ pf))). + specialize (H (S i1) ( (proj1 (Nat.succ_lt_mono _ _) pf))). simpl in H. rewrite (ivector_nth_prf_irrelevance _ i _ _ pf) in H. now rewrite (ivector_nth_prf_irrelevance _ i0 _ _ pf) in H. @@ -1488,7 +1488,7 @@ Qed. := match n as n' return (forall i (pf:(i ivector B n' with | 0%nat => fun _ => tt | S m => - fun f => (f 0%nat (Nat.lt_0_succ _) , ivector_create m (fun i pf => f (S i) (lt_n_S _ _ pf))) + fun f => (f 0%nat (Nat.lt_0_succ _) , ivector_create m (fun i pf => f (S i) (proj1 (Nat.succ_lt_mono _ _) pf))) end. Lemma ivector_nth_from_list {A} @@ -1576,7 +1576,7 @@ Qed. Lemma vector_nth_cons_S {A} {n} a (vec : vector A n) i pf : vector_nth (S i) pf (vector_cons a vec) = - vector_nth i (lt_S_n i n pf) vec. + vector_nth i (proj2 (Nat.succ_lt_mono i n) pf) vec. Proof. unfold vector_cons, vector_nth, proj1_sig. destruct vec. @@ -1749,7 +1749,7 @@ Section Sequence. Definition ivector_to_sequence {T} {n} (v : ivector T n) (default : T) : nat -> T := (fun i => - match lt_dec i n with + match Compare_dec.lt_dec i n with | left pf => ivector_nth i pf v | right _ => default end). diff --git a/coq/utils/ELim_Seq.v b/rocq/utils/ELim_Seq.v similarity index 99% rename from coq/utils/ELim_Seq.v rename to rocq/utils/ELim_Seq.v index b18226be..6481ae45 100644 --- a/coq/utils/ELim_Seq.v +++ b/rocq/utils/ELim_Seq.v @@ -172,7 +172,7 @@ Proof. - intros. exists 0%nat; intros; lia. - intros HH. - specialize (IHN (fun x pf => P x (lt_S _ _ pf))). + specialize (IHN (fun x pf => P x (Nat.lt_lt_succ_r _ _ pf))). cut_to IHN. + simpl in IHN. specialize (HH N (Nat.lt_succ_diag_r _)). diff --git a/coq/utils/FiniteType.v b/rocq/utils/FiniteType.v similarity index 97% rename from coq/utils/FiniteType.v rename to rocq/utils/FiniteType.v index 51f57628..4ae619a1 100644 --- a/coq/utils/FiniteType.v +++ b/rocq/utils/FiniteType.v @@ -1,5 +1,5 @@ Require Import ClassicalDescription. -Require Import FinFun List EquivDec FunctionalExtensionality Lia Eqdep_dec. +Require Import Arith FinFun List EquivDec FunctionalExtensionality Lia Eqdep_dec. Require Import LibUtils. Require Import Isomorphism ListAdd BasicUtils. @@ -202,13 +202,12 @@ Lemma fold_right_add_const {A:Type} c (l:list A) : Proof. induction l; simpl; trivial. rewrite IHl; simpl. - rewrite NPeano.Nat.mul_succ_r. - now rewrite NPeano.Nat.add_comm. + lia. Qed. Lemma fold_right_mult_const {A:Type} c (l:list A) : fold_right Nat.mul 1 - (map (fun _ => c) l) = NPeano.Nat.pow c (length l). + (map (fun _ => c) l) = Nat.pow c (length l). Proof. induction l; simpl; trivial. now rewrite IHl; simpl. @@ -232,12 +231,12 @@ Proof. (FiniteType_fun_dep_elems_aux fin_elms0 (fun (x0 : A) (_ : In_strong x0 fin_elms0) => fin_elms)))) by (intros; now rewrite map_length). rewrite fold_right_add_const. - rewrite NPeano.Nat.mul_comm. + rewrite Nat.mul_comm. now rewrite <- IHfin_elms0. Qed. Lemma FiniteType_fun_size {A:Type} {dec:EqDec A eq} {B:Type} (finA:FiniteType A) (finB:FiniteType B) - : length (@fin_elms _ (FiniteType_fun finA finB)) = NPeano.pow (length (@fin_elms _ finB)) (length (@fin_elms _ finA)). + : length (@fin_elms _ (FiniteType_fun finA finB)) = Nat.pow (length (@fin_elms _ finB)) (length (@fin_elms _ finA)). Proof. unfold FiniteType_fun. rewrite FiniteType_fun_dep_size. @@ -507,7 +506,7 @@ Qed. intros index eqq. match_destr_in eqq. - invcs eqq. - rewrite PeanoNat.Nat.sub_diag; simpl. + rewrite Nat.sub_diag; simpl. congruence. - specialize (IHl _ eqq). generalize (find_index_aux_bounds eqq); intros le1. @@ -523,7 +522,7 @@ Qed. Proof. intros HH. specialize (find_index_aux_correct HH). - now rewrite PeanoNat.Nat.sub_0_r. + now rewrite Nat.sub_0_r. Qed. Lemma find_index_aux_first {l:list A} {a} {index:nat} {n} : @@ -537,7 +536,7 @@ Qed. intros index eqq. match_destr_in eqq. - invcs eqq. - rewrite PeanoNat.Nat.sub_diag; simpl. + rewrite Nat.sub_diag; simpl. lia. - specialize (IHl _ eqq). intros [|]; simpl; intros eqq2. @@ -768,10 +767,11 @@ Section countableType. (find_index' x l = length l)%nat. Proof. generalize (find_index'_in x l); intros. - destruct (Lt.le_lt_or_eq _ _ (find_index'_le x l)) - ; firstorder. - - lia. - - intros inn. + destruct (proj1 (Nat.lt_eq_cases (find_index' x l) (length l))). + - apply find_index'_le. + - firstorder; lia. + - firstorder. + intros inn. apply H in inn. lia. Qed. diff --git a/coq/utils/FiniteTypeVector.v b/rocq/utils/FiniteTypeVector.v similarity index 100% rename from coq/utils/FiniteTypeVector.v rename to rocq/utils/FiniteTypeVector.v diff --git a/coq/utils/Isomorphism.v b/rocq/utils/Isomorphism.v similarity index 100% rename from coq/utils/Isomorphism.v rename to rocq/utils/Isomorphism.v diff --git a/coq/utils/ListAdd.v b/rocq/utils/ListAdd.v similarity index 99% rename from coq/utils/ListAdd.v rename to rocq/utils/ListAdd.v index 50ccebea..7cac708b 100644 --- a/coq/utils/ListAdd.v +++ b/rocq/utils/ListAdd.v @@ -398,15 +398,15 @@ Section fp. rewrite Forall_forall in H. apply H. apply nth_In. - now apply lt_S_n. + auto with arith. + right. rewrite Forall_forall in H. symmetry. apply H. apply nth_In. - now apply lt_S_n. - + apply lt_S_n in n1bound. - apply lt_S_n in n2bound. + auto with arith. + + apply Nat.succ_lt_mono in n1bound. + apply Nat.succ_lt_mono in n2bound. destruct (IHFP _ _ n1bound n2bound) as [?|?]; auto. Qed. @@ -2036,7 +2036,7 @@ Section cross_product. apply Forall_impl; intros. lia. + - rewrite <- Plus.plus_Snm_nSm. + replace (n + S (length l0)) with (S (n + length l0)) by auto with arith. specialize (IHl0 (map (fun '(a0, b0) => b0 ++ [a0]) (list_prod a acc)) (S n)). apply IHl0. rewrite Forall_map; simpl. @@ -2212,12 +2212,12 @@ Qed. Lemma list_max_in l : l <> nil -> In (list_max l) l. Proof. induction l; simpl; [eauto |]; intros _. - destruct (Max.max_dec a (list_max l)) + destruct (Nat.max_dec a (list_max l)) ; rewrite e in * ; eauto. destruct l. - simpl in e. - rewrite Max.max_0_r in e; simpl + rewrite Nat.max_0_r in e; simpl ; eauto. - right; apply IHl; congruence. Qed. diff --git a/coq/utils/NumberIso.v b/rocq/utils/NumberIso.v similarity index 99% rename from coq/utils/NumberIso.v rename to rocq/utils/NumberIso.v index caad723b..e56c1d41 100644 --- a/coq/utils/NumberIso.v +++ b/rocq/utils/NumberIso.v @@ -1,4 +1,4 @@ -Require Import BinNums BinNat Nat List. +Require Import ZArith BinNums BinNat Nat List. Require Import Lia. Require Import LibUtils Isomorphism PairEncoding. diff --git a/coq/utils/PairEncoding.v b/rocq/utils/PairEncoding.v similarity index 99% rename from coq/utils/PairEncoding.v rename to rocq/utils/PairEncoding.v index f6fcf493..4acab6ea 100644 --- a/coq/utils/PairEncoding.v +++ b/rocq/utils/PairEncoding.v @@ -1,4 +1,4 @@ -Require Import BinNums Nat List BinInt. +Require Import Arith BinNums Nat List BinInt. Require Import Lia. Require Import LibUtils Isomorphism ListAdd. @@ -1255,7 +1255,7 @@ Global Instance nat_pair_encoder : Isomorphism (nat*nat) nat := pair_encoder nat now subst. - destruct (Compare_dec.le_dec c1 c). + specialize (IHc l). - eapply Le.le_trans. + eapply Nat.le_trans. apply IHc. rewrite seq_S with (len := (S c)). rewrite map_app, list_max_app. diff --git a/coq/utils/PushNeg.v b/rocq/utils/PushNeg.v similarity index 100% rename from coq/utils/PushNeg.v rename to rocq/utils/PushNeg.v diff --git a/coq/utils/Quotient.v b/rocq/utils/Quotient.v similarity index 100% rename from coq/utils/Quotient.v rename to rocq/utils/Quotient.v diff --git a/coq/utils/RbarAdd.v b/rocq/utils/RbarAdd.v similarity index 97% rename from coq/utils/RbarAdd.v rename to rocq/utils/RbarAdd.v index 201baa26..5c33c7b2 100644 --- a/coq/utils/RbarAdd.v +++ b/rocq/utils/RbarAdd.v @@ -218,7 +218,7 @@ Section sums. unfold sum_Rbar_n; intros. induction n; [simpl; f_equal; lra | ]. rewrite seq_Sn. - rewrite plus_0_l. + rewrite Nat.add_0_l. repeat rewrite map_app. repeat rewrite list_Rbar_sum_nneg_plus; simpl @@ -895,8 +895,8 @@ Qed. + now rewrite iso_b_f. + apply in_seq. split; [lia |]. - rewrite plus_0_l. - apply le_lt_n_Sm. + rewrite Nat.add_0_l. + apply Nat.lt_succ_r. destruct H1. apply in_seq in H1. apply in_seq in H2. @@ -1616,7 +1616,7 @@ Proof. split. + now apply classic_min_of_some in H0. + intros. - apply NPeano.Nat.nlt_ge; intros nlt. + apply Nat.nlt_ge; intros nlt. eapply classic_min_of_some_first in H0; try apply nlt. tauto. - intros. @@ -1624,7 +1624,7 @@ Proof. apply zerotails in H. apply is_lim_seq_spec in H. simpl in H. - elimtype False. + exfalso. destruct (H ϵ) as [N ?]. elim (H1 N). red; intros. @@ -1752,10 +1752,10 @@ Proof. Qed. Definition zerotails_incr_mult (a : nat -> R) (pf:ex_series a) n : R - := Series (fun n0 : nat => if le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0). + := Series (fun n0 : nat => if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0). Lemma zerotails_incr_mult_ex (a : nat -> R) (pf:ex_series a) n : - ex_series (fun n0 : nat => if le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0). + ex_series (fun n0 : nat => if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0). Proof. apply (ex_series_incr_n _ n). apply (ex_series_ext (fun _ => 0)). @@ -1769,7 +1769,7 @@ Proof. Qed. Lemma zerotails_incr_mult_trunc (a : nat -> R) (pf:ex_series a) n : - zerotails_incr_mult a pf n = sum_n (fun n0 : nat => if le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0) n. + zerotails_incr_mult a pf n = sum_n (fun n0 : nat => if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0) n. Proof. unfold zerotails_incr_mult. apply is_series_unique. @@ -1777,11 +1777,12 @@ Proof. apply is_lim_seq_spec. simpl; intros. exists n; intros. - generalize (sum_n_m_sum_n (fun n1 : nat => if le_dec (S (zerotails_eps2k_fun a pf n1)) n then 1 else 0) n n0). + generalize (sum_n_m_sum_n (fun n1 : nat => if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf n1)) n then 1 else 0) n n0 H). match goal with [|- context [minus ?x ?y]] => replace (minus x y) with (x - y) by reflexivity end; simpl. - intros HH; rewrite <- HH; trivial. + intros HH. + eapply Rle_lt_trans; [right; apply (symmetry (f_equal Rabs HH)) |]. erewrite (sum_n_m_ext_loc _ (fun _ => zero)). - rewrite sum_n_m_const_zero. unfold zero; simpl. @@ -1826,7 +1827,7 @@ Proof. assert (0 <= sum_f_R0' (fun x : nat => if - le_dec (S (zerotails_eps2k_fun a pf (S (x + S n)))) + Compare_dec.le_dec (S (zerotails_eps2k_fun a pf (S (x + S n)))) (n + S (zerotails_eps2k_fun a pf m)) then 1 else 0) (zerotails_eps2k_fun a pf m) @@ -1941,13 +1942,13 @@ Proof. transitivity ( Series (fun k : nat => Series (fun n : nat => a n * - if le_dec (S (zerotails_eps2k_fun a pf k)) n + if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf k)) n then 1 else 0))). { apply Series_ext; intros. rewrite (Series_incr_n_aux (fun n0 : nat => - a n0 * (if le_dec (S (zerotails_eps2k_fun a pf n))%nat n0 then 1 else 0)) + a n0 * (if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf n))%nat n0 then 1 else 0)) (S (zerotails_eps2k_fun a pf n))). - apply Series_ext; intros. match_destr. @@ -1965,7 +1966,7 @@ Proof. (fun n : nat => Series (fun k : nat => - a n * (if le_dec (S (zerotails_eps2k_fun a pf k)) n then 1 else 0)))). + a n * (if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf k)) n then 1 else 0)))). { apply Series_nneg_nested_swap. - intros. @@ -1983,7 +1984,7 @@ Proof. unfold pointwise_relation; simpl; intros. rewrite <- ELim_seq_incr_1. rewrite ELim_seq_ext with - (v := (fun n => Finite (sum_n (fun j : nat => a j * (if le_dec (S (zerotails_eps2k_fun a pf a0)) j then 1 else 0)) n))). + (v := (fun n => Finite (sum_n (fun j : nat => a j * (if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf a0)) j then 1 else 0)) n))). rewrite Elim_seq_fin. rewrite <- ELim_seq_incr_1. rewrite ELim_seq_ext with @@ -1993,7 +1994,7 @@ Proof. rewrite ex_series_Lim_seq. -- rewrite (Series_incr_n_aux (fun n0 : nat => - a n0 * (if le_dec (S (zerotails_eps2k_fun a pf a0))%nat n0 then 1 else 0)) + a n0 * (if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf a0))%nat n0 then 1 else 0)) (S (zerotails_eps2k_fun a pf a0))). ++ apply Rbar_finite_eq. apply Series_ext; intros. @@ -2011,7 +2012,7 @@ Proof. (fun x => a ((S (zerotails_eps2k_fun a pf a0)) + x)%nat)). ++ intros; f_equal; lia. ++ now apply ex_series_incr_n. - -- apply (ex_series_le (fun j : nat => a j * (if le_dec (S (zerotails_eps2k_fun a pf a0)) j then 1 else 0)) a); trivial. + -- apply (ex_series_le (fun j : nat => a j * (if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf a0)) j then 1 else 0)) a); trivial. intros. unfold norm; simpl. unfold abs; simpl. @@ -2068,7 +2069,7 @@ Qed. - intros. simpl. destruct H2. - destruct (le_dec x n). + destruct (Compare_dec.le_dec x n). + now apply H2. + specialize (H2 x). cut_to H2; try lia; simpl in H2. @@ -2085,7 +2086,7 @@ Qed. split; try lia. assert (Rbar_le (f x) (f (Init.Nat.max x N))). { - destruct (le_dec N x). + destruct (Compare_dec.le_dec N x). - rewrite Nat.max_l; try lia. apply Rbar_le_refl. - rewrite Nat.max_r; try lia. @@ -2116,7 +2117,7 @@ Qed. split; try lia. assert (Rbar_le (f x) (f (Init.Nat.max x N))). { - destruct (le_dec N x). + destruct (Compare_dec.le_dec N x). - rewrite Nat.max_l; try lia. apply Rbar_le_refl. - rewrite Nat.max_r; try lia. @@ -2129,7 +2130,7 @@ Qed. - unfold is_ELimSup_seq, is_sup_seq. split; intros. + destruct (H0 M). - destruct (le_dec x n). + destruct (Compare_dec.le_dec x n). * now apply H1. * assert (n <= x)%nat by lia. apply Rbar_le_lt_trans with (y := f x). diff --git a/coq/utils/RealAdd.v b/rocq/utils/RealAdd.v similarity index 99% rename from coq/utils/RealAdd.v rename to rocq/utils/RealAdd.v index 4ee9811e..9c237488 100644 --- a/coq/utils/RealAdd.v +++ b/rocq/utils/RealAdd.v @@ -6,7 +6,7 @@ Require Import Coq.Reals.Rfunctions. Require Import Coq.Reals.Rprod Coq.Reals.ROrderedType. Require Import Ranalysis_reg. Require Import Coquelicot.Coquelicot. -Require Import Lia Lra. +Require Import ZArith Lia Lra. Require Import Reals.Integration. Require Import Coq.Reals.SeqProp. Require Import Rtrigo_def. @@ -1615,7 +1615,7 @@ Proof. unfold Rpower. unfold ln. match_destr; try lra. - - elimtype False; try tauto; lra. + - exfalso; try tauto; lra. - rewrite Rmult_0_r. now rewrite exp_0. Qed. @@ -1641,7 +1641,7 @@ Lemma Rpower_base_0 e : Rpower 0 e = 1. Proof. unfold Rpower, ln. match_destr. - - elimtype False; try tauto; lra. + - exfalso; try tauto; lra. - rewrite Rmult_0_r. now rewrite exp_0. Qed. @@ -2762,7 +2762,7 @@ Section lim_seq_sup_seq. simpl. split; intros. + destruct H1. - destruct (le_dec x n). + destruct (Compare_dec.le_dec x n). * now apply H1. * assert (n <= x)%nat by lia. apply Rle_lt_trans with (r2 := f x). @@ -2782,7 +2782,7 @@ Section lim_seq_sup_seq. unfold is_sup_seq; simpl; intros. specialize (H0 M). destruct H0 as [N H0]. - destruct (le_dec N n). + destruct (Compare_dec.le_dec N n). + now apply H0. + assert (n <= N)%nat by lia. apply Rle_lt_trans with (r2 := f N). @@ -3371,7 +3371,7 @@ Section Rmax_list. (ex:exists x, In x l /\ P x) : {x | In x l /\ P x}. Proof. induction l; simpl. - - elimtype False. + - exfalso. destruct ex ; intuition. - destruct (dec a). + exists a ; eauto. @@ -4741,7 +4741,7 @@ Proof. destruct (Hx posreal_one). destruct (fin_seq_bounded x x0). exists (Rmax x1 ((Rabs c)+1)); intros. - destruct (lt_dec n x0). + destruct (Compare_dec.lt_dec n x0). - eapply Rle_trans. + apply H0; lia. + apply Rmax_l. @@ -4905,7 +4905,7 @@ Qed. Lemma sum_n_sum_f_clipped (f : nat -> R) (N : nat) : forall (n:nat), (n >= N)%nat -> - sum_n f N = sum_n (fun j => if (le_dec j N) then (f j) else 0) n. + sum_n f N = sum_n (fun j => if (Compare_dec.le_dec j N) then (f j) else 0) n. Proof. intros. replace (n) with (N + (n - N))%nat by lia. @@ -6652,7 +6652,7 @@ Section powerRZ. unfold proj1_sig. match_destr; match_destr. destruct (Z_le_gt_dec x x0); trivial. - elimtype False. + exfalso. assert (x0 <= x - 1)%Z by lia. assert (powerRZ base x0 <= powerRZ base (x-1)%Z). { @@ -6820,7 +6820,7 @@ Qed. Proof. induction l; try now simpl. simpl. - destruct (lt_dec 0 (length l)). + destruct (Compare_dec.lt_dec 0 (length l)). + rewrite prod_f_R0_succ; try assumption. rewrite IHl. rewrite Rmult_assoc. f_equal. rewrite prod_f_R0_pred; try assumption. @@ -6894,7 +6894,7 @@ Qed. prod_f_R0 f n2 = 0. Proof. intros. - destruct (lt_dec n1 n2). + destruct (Compare_dec.lt_dec n1 n2). - rewrite prod_SO_split with (k := n1) (n := n2); trivial. rewrite prod_f_R0_n; trivial. apply Rmult_0_l. diff --git a/coq/utils/RiemannAdd.v b/rocq/utils/RiemannAdd.v similarity index 99% rename from coq/utils/RiemannAdd.v rename to rocq/utils/RiemannAdd.v index f4c40523..bf88c61a 100644 --- a/coq/utils/RiemannAdd.v +++ b/rocq/utils/RiemannAdd.v @@ -1,4 +1,4 @@ -Require Import Reals.Rbase Coq.Reals.RList. +Require Import ZArith Reals.Rbase Coq.Reals.RList. Require Import Reals.Rfunctions. Require Import Ranalysis_reg. Require Import Reals.Integration. diff --git a/coq/utils/StreamAdd.v b/rocq/utils/StreamAdd.v similarity index 98% rename from coq/utils/StreamAdd.v rename to rocq/utils/StreamAdd.v index 68bdd213..a3df94c7 100644 --- a/coq/utils/StreamAdd.v +++ b/rocq/utils/StreamAdd.v @@ -135,7 +135,7 @@ Section Cutting. List.firstn i (firstn j s) = List.firstn j (firstn i s). Proof. repeat rewrite firstn_firstn. - rewrite Min.min_comm. + rewrite PeanoNat.Nat.min_comm. trivial. Qed. diff --git a/coq/utils/StreamLimits.v b/rocq/utils/StreamLimits.v similarity index 100% rename from coq/utils/StreamLimits.v rename to rocq/utils/StreamLimits.v diff --git a/coq/utils/Sums.v b/rocq/utils/Sums.v similarity index 98% rename from coq/utils/Sums.v rename to rocq/utils/Sums.v index d50af400..0b2b3e70 100644 --- a/coq/utils/Sums.v +++ b/rocq/utils/Sums.v @@ -558,9 +558,9 @@ Section inf_sum'. intros n ngt. specialize (H1 n). - cut_to H1; [ | apply (le_trans _ (max N1 N2)); auto with arith]. + cut_to H1; [ | apply (Nat.le_trans _ (max N1 N2)); auto with arith]. specialize (H2 n). - cut_to H2; [ | apply (le_trans _ (max N1 N2)); auto with arith]. + cut_to H2; [ | apply (Nat.le_trans _ (max N1 N2)); auto with arith]. rewrite sum_f_R0'_plus. generalize (R_dist_plus (sum_f_R0' f1 n) sum1 (sum_f_R0' f2 n) sum2); intros. @@ -685,8 +685,8 @@ Section harmonic. intros neq. induction b; simpl. - lia. - - apply gt_n_S in IHb. - eapply le_gt_trans; try eassumption. + - apply Nat.succ_lt_mono in IHb. + eapply Nat.lt_le_trans; try eassumption. apply Sle_mult_gt1; lia. Qed. Lemma sum_f_R0'_eq2 n : @@ -763,10 +763,10 @@ Section harmonic. lra. - rewrite Nat.pow_mul_r. simpl. - rewrite NPeano.Nat.pow_add_r. + rewrite Nat.pow_add_r. unfold ge. replace N with (N * 1)%nat at 1 by lia. - apply mult_le_compat. + apply Nat.mul_le_mono. + generalize (pow_exp_gt 4 N) ; lia. + generalize (Z.to_nat (up l)); intros n. @@ -1004,7 +1004,7 @@ Proof. apply in_seq. split; [lia|]. simpl. - eapply lt_le_trans; try apply mbig. + eapply Nat.lt_le_trans; try apply mbig. unfold M. unfold lt. generalize (list_max_upper (map ginv (seq 0 N))); intros FF. @@ -1034,7 +1034,7 @@ Proof. assert(l2_lower: (forall x, In x l2 -> x >= N)%nat). { intros. - destruct (ge_dec x N); trivial. + destruct (Compare_dec.ge_dec x N); trivial. apply Compare_dec.not_ge in n. assert (inn:In x gpre). { @@ -1047,7 +1047,7 @@ Proof. } pose (nn:=List.list_max l2). destruct (list_max_le l2 nn) as [l2_upper _]. - specialize (l2_upper (le_refl _)). + specialize (l2_upper (Nat.le_refl _)). assert (incl1:incl l2 (seq N (S nn-N))). { intros ? inn. @@ -1099,7 +1099,7 @@ Proof. rewrite map_map. specialize (N2_lt nn (max N1 N2))%nat. - cut_to N2_lt. + cut_to N2_lt; try lia. - unfold R_dist in N2_lt. repeat rewrite sum_f_R0_sum_f_R0' in N2_lt. repeat rewrite sum_f_R0'_list_sum in N2_lt. @@ -1115,11 +1115,7 @@ Proof. apply in_map_iff in H. destruct H as [?[??]]; subst. apply Rabs_pos. - + apply le_plus_minus_r. - lia. - - red. - transitivity N; lia. - - lia. + + lia. Qed. Corollary infinite_sum'_pos_perm (g:nat->nat) (f:nat -> R) l: @@ -1247,7 +1243,7 @@ Qed. Lemma sum_n_m_shift (α : nat -> R) (k n0 : nat) : - sum_n_m α k (n0 + k)%nat = sum_n (fun n1 : nat => α (n1 + k)%nat) n0. + @sum_n_m R_AbelianGroup α k (n0 + k)%nat = @sum_n R_AbelianGroup (fun n1 : nat => α (n1 + k)%nat) n0. Proof. unfold sum_n. induction n0. @@ -1263,7 +1259,7 @@ Qed. Lemma sum_n_m_pos a n1 n2 : (forall n, (n1 <= n <= n2)%nat -> 0 <= a n) -> - 0 <= (sum_n_m a n1 n2). + 0 <= (@sum_n_m R_AbelianGroup a n1 n2). Proof. intros. rewrite sum_n_m_fold_right_seq. @@ -1280,7 +1276,7 @@ Qed. Lemma sum_n_pos_incr a n1 n2 : (forall n, (n1 < n <= n2)%nat -> 0 <= a n) -> - (n1 <= n2)%nat -> sum_n a n1 <= sum_n a n2. + (n1 <= n2)%nat -> @sum_n R_AbelianGroup a n1 <= @sum_n R_AbelianGroup a n2. Proof. intros. destruct (Nat.eq_dec n1 n2); [rewrite e; lra|]. @@ -1349,7 +1345,7 @@ Section Sequences. destruct H. exists (Rmax ((Rabs x)+1) (Rmax_list (map (fun n => Rabs (f n)) (seq 0 x0)))). intros. - destruct (dec_le x0 n). + destruct (Compare_dec.dec_le x0 n). - specialize (H n H1). apply Rle_trans with (r2 := Rabs x + 1). + simpl in H. @@ -1862,7 +1858,7 @@ Qed. rewrite <-H2. simpl. do 2 rewrite <-sum_n_Reals. replace n with (S (pred n)) by lia. - rewrite sum_n_m_sum_n; try lia. + rewrite (@sum_n_m_sum_n R_AbelianGroup); try lia. reflexivity. Qed. @@ -1882,7 +1878,7 @@ Qed. rewrite <-H2. simpl. do 2 rewrite <-sum_n_Reals. replace n with (S (pred n)) by lia. - rewrite sum_n_m_sum_n; try lia. + rewrite (@sum_n_m_sum_n R_AbelianGroup); try lia. reflexivity. Qed. @@ -1894,7 +1890,7 @@ Qed. Series (fun i => f (S m + i)%nat). Proof. intros. - destruct (lt_dec 0 n). + destruct (Compare_dec.lt_dec 0 n). - apply sum_n_m_Series1; trivial; lia. - assert (n=0)%nat by lia. setoid_rewrite H1. @@ -1929,7 +1925,7 @@ Qed. apply Rmax_list_map_seq_lt_gen; try lia. intros. specialize (H0 (N+k)%nat n). - destruct (le_dec (N + k)%nat n). + destruct (Compare_dec.le_dec (N + k)%nat n). - apply H0; lia. - assert (n < N + k)%nat by lia. rewrite sum_n_m_zero; try lia. @@ -1957,7 +1953,7 @@ Qed. apply Rmax_list_map_seq_lt_gen; try lia. intros. specialize (H0 (S (N + k)) (n - 1)%nat). - destruct (lt_dec (n-1)%nat (S (N + k))). + destruct (Compare_dec.lt_dec (n-1)%nat (S (N + k))). - rewrite sum_n_m_zero; try lia. unfold zero; simpl. cbn. @@ -2306,7 +2302,7 @@ Section tails. { intros. unfold s. - destruct (lt_dec n m). + destruct (Compare_dec.lt_dec n m). - unfold sum_n. apply Rge_le. rewrite (sum_n_m_Chasles _ _ n); try lia. @@ -2339,7 +2335,7 @@ Section tails. - unfold Rdiv. rewrite sum_n_m_ext with (b := (fun n : nat => scal (/ s (S N + k)%nat) (gamma n))). + rewrite sum_n_m_scal_l. - rewrite sum_n_m_sum_n; try lia. + rewrite (@sum_n_m_sum_n R_AbelianGroup); try lia. unfold s. unfold scal; simpl. unfold mult; simpl. @@ -2387,7 +2383,8 @@ Section tails. rewrite sum_n_m_ext with (b := fun n => / s n * gamma n) by (intros; unfold Rdiv; now rewrite Rmult_comm). generalize (sum_n_m_sum_n (fun n : nat => / s n * gamma n) N (S N + k)); intros. cut_to H5; try lia. - rewrite H5. + eapply Rle_lt_trans; [right; apply H5 |]. + unfold minus; simpl. unfold plus, opp; simpl. rewrite Rabs_minus_sym in H3. @@ -2396,7 +2393,7 @@ Section tails. unfold minus, plus, opp in H5; simpl in H5. unfold Rminus. replace (S N + k)%nat with (S (N + k))%nat by lia. - rewrite <- H5. + etransitivity; [right; symmetry; apply H5 |]. apply Rle_ge. apply sum_n_m_pos. intros. @@ -2457,7 +2454,7 @@ Section tails. sum_n (fun n0 : nat => X (n0 + S m)%nat) a. Proof. rewrite <-sum_n_m_shift. - rewrite sum_n_m_sum_n; try lia. + rewrite (@sum_n_m_sum_n R_AbelianGroup); try lia. reflexivity. Qed. @@ -2517,7 +2514,7 @@ Section tails. cut_to H2; trivial. - destruct H2 as [rho [? ?]]. assert (0 < 1) by lra. - exists (fun n => if (lt_dec n N) then (mkposreal _ H4) else rho (n - N)%nat). + exists (fun n => if (Compare_dec.lt_dec n N) then (mkposreal _ H4) else rho (n - N)%nat). split. + apply is_lim_seq_incr_n with (N := N). apply is_lim_seq_ext with (u := rho); trivial. diff --git a/coq/utils/Utils.v b/rocq/utils/Utils.v similarity index 100% rename from coq/utils/Utils.v rename to rocq/utils/Utils.v diff --git a/coq/utils/Vector.v b/rocq/utils/Vector.v similarity index 96% rename from coq/utils/Vector.v rename to rocq/utils/Vector.v index 39b2027d..03aa9b70 100644 --- a/coq/utils/Vector.v +++ b/rocq/utils/Vector.v @@ -1,4 +1,4 @@ -Require Import List Lia. +Require Import Arith List Lia. Require Import LibUtils ListAdd BasicUtils. Section Vector. @@ -38,7 +38,7 @@ Section Vector. init | S bound1 => fun pf0 : S bound1 <= m => - let an := vector_fold_right1_bounded_dep f init singleton v bound1 (Le.le_Sn_le bound1 m pf0) in + let an := vector_fold_right1_bounded_dep f init singleton v bound1 (Nat.lt_le_incl bound1 m pf0) in match bound1 as bound1' return (A bound1' -> S bound1' <= m -> A (S bound1')) with | 0 => fun (_ : A 0) (pf1 : 1 <= m) => @@ -58,7 +58,7 @@ Section Vector. - apply f. + exact (v (exist _ n pf)). + apply IHn. - exact (Le.le_Sn_le _ _ pf). + exact (Nat.lt_le_incl _ _ pf). Defined. Definition vnil {T} : Vector T 0. @@ -74,19 +74,19 @@ Section Vector. + exact x. + apply v. exists i. - apply NPeano.Nat.le_neq. + apply Nat.le_neq. split; trivial. now apply le_S_n in pf. Defined. - Definition vhd {T} {n} (v:Vector T (S n)) : T := v (exist _ (0%nat) (NPeano.Nat.lt_0_succ n)). - Definition vlast {T} {n} (v:Vector T (S n)) : T := v (exist _ (n%nat) (NPeano.Nat.lt_succ_diag_r n)). + Definition vhd {T} {n} (v:Vector T (S n)) : T := v (exist _ (0%nat) (Nat.lt_0_succ n)). + Definition vlast {T} {n} (v:Vector T (S n)) : T := v (exist _ (n%nat) (Nat.lt_succ_diag_r n)). Definition vdrop_last {T} {n} (v:Vector T (S n)) : Vector T n. Proof. intros [i pf]; apply v. exists i. - apply NPeano.Nat.lt_lt_succ_r; trivial. + apply Nat.lt_lt_succ_r; trivial. Defined. @@ -120,11 +120,11 @@ Section Vector. Definition vector_fold_right1_dep {A:nat->Type} {B} (f:forall n, B->A n->A (S n)) (init:A 0%nat) (singleton:B->A 1%nat) {m:nat} (v:Vector B m) : A m - := vector_fold_right1_bounded_dep f init singleton v m (Le.le_refl _). + := vector_fold_right1_bounded_dep f init singleton v m (Nat.le_refl _). Definition vector_fold_right_dep {A:nat->Type} {B} (f:forall n, B->A n->A (S n)) (init:A 0%nat) {m:nat} (v:Vector B m) : A m - := vector_fold_right_bounded_dep f init v m (Le.le_refl _). + := vector_fold_right_bounded_dep f init v m (Nat.le_refl _). Definition vector_fold_right1 {A B:Type} (f:B->A->A) (init:A) (singleton:B->A) {m:nat} (v:Vector B m) := vector_fold_right1_dep (A:=fun _ => A) (fun _ => f) init singleton v. @@ -210,7 +210,7 @@ Section Vector. Definition list_fold_right1_dep {A:nat->Type} {B} (f:forall n, B->A n->A (S n)) (init:A 0%nat) (singleton:B->A 1%nat) (l:list B) : A (length l) - := list_fold_right1_bounded_dep f init singleton l (length l) (Le.le_refl _). + := list_fold_right1_bounded_dep f init singleton l (length l) (Nat.le_refl _). Definition list_fold_right_dep {A:nat->Type} {B} (f:forall n, B->A n->A (S n)) (init:A 0%nat) (l:list B) : A (length l) @@ -274,7 +274,7 @@ Section Vector. Proof. intros [i pf]. unfold vcons, vlast, vdrop_last. - destruct (NPeano.Nat.eq_dec i n) + destruct (Nat.eq_dec i n) ; subst ; f_equal ; apply index_pf_irrel. @@ -380,7 +380,7 @@ Section Vector. intros; subst. intros [i pf]. unfold vcons. - destruct (NPeano.Nat.eq_dec i n); simpl; trivial. + destruct (Nat.eq_dec i n); simpl; trivial. Qed. Lemma vdrop_last_proper {T} {n} (x y:Vector T (S n)) : x =v= y -> vdrop_last x =v= vdrop_last y. @@ -417,7 +417,7 @@ Section Vector. rewrite (vector_fold_right_dep_ext _ _ (vector_Sn_split v)). unfold vector_fold_right_dep. simpl. - destruct (NPeano.Nat.eq_dec m m) ; [ | congruence]. + destruct (Nat.eq_dec m m) ; [ | congruence]. f_equal. erewrite vector_fold_right_dep_bounded_pf_ext. erewrite vector_fold_right_dep_bounded_cut_down. @@ -514,7 +514,7 @@ Section Vector. Lemma vector_fold_right1_dep_1 {A:nat->Type} {B} (f:forall n,B->A n->A (S n)) (init:A 0%nat) sing (v:Vector B 1) : - vector_fold_right1_dep f init sing v = sing (v (exist _ 0 NPeano.Nat.lt_0_1)). + vector_fold_right1_dep f init sing v = sing (v (exist _ 0 Nat.lt_0_1)). Proof. unfold vector_fold_right1_dep. simpl. @@ -677,7 +677,7 @@ Section Vector. - intros [[i pf] eqqi]. unfold vector_to_list in *. rewrite vector_fold_right_Sn; simpl. - destruct (NPeano.Nat.eq_dec i n). + destruct (Nat.eq_dec i n). + left. unfold vlast. subst. @@ -697,7 +697,7 @@ Section Vector. unfold vcons. split. - intros [[i pf] eqq]. - destruct (NPeano.Nat.eq_dec i n). + destruct (Nat.eq_dec i n). + subst; eauto. + right. eexists (exist _ i _). @@ -708,10 +708,10 @@ Section Vector. - intros [eqq | inn]. + red. eexists (exist _ n _). - destruct (NPeano.Nat.eq_dec n n); congruence. + destruct (Nat.eq_dec n n); congruence. + destruct inn as [[i pf] eqq]. eexists (exist _ i _). - destruct (NPeano.Nat.eq_dec i n); [lia | ]. + destruct (Nat.eq_dec i n); [lia | ]. erewrite index_pf_irrel; eauto. Unshelve. simpl; lia. @@ -777,7 +777,7 @@ Section Vector. - lia. - rewrite vector_fold_right_dep_Sn. simpl. - destruct (NPeano.Nat.eq_dec i n). + destruct (Nat.eq_dec i n). + subst. unfold vlast. erewrite index_pf_irrel; eauto. @@ -855,7 +855,7 @@ Section Vector. Qed. Definition bounded_seq (start len : nat) : list {n':nat | n' < start+len}%nat - := bounded_seq_bounded start len len (Le.le_refl _). + := bounded_seq_bounded start len len (Nat.le_refl _). Definition bounded_seq0 len : list {n':nat | n' < len}%nat := bounded_seq 0 len. @@ -918,7 +918,7 @@ Section Vector. + rewrite vector_fold_right_Sn. intros [Plast Pdrop]. intros [i pf]. - destruct (NPeano.Nat.eq_dec i n). + destruct (Nat.eq_dec i n). * unfold vlast in Plast. subst. erewrite index_pf_irrel; eauto. @@ -956,7 +956,7 @@ Section Vector. specialize (IHn _ _ eqq2). rewrite eqq1. unfold vcons. - destruct (NPeano.Nat.eq_dec i n); trivial. + destruct (Nat.eq_dec i n); trivial. Qed. Lemma vectoro_to_ovector_forall_some_b {A n} (vo:Vector (option A) n) (v:Vector A n) : @@ -1049,7 +1049,7 @@ Section Vector. - intros eqq. rewrite vector_fold_right_dep_Sn. unfold vlast. - destruct (NPeano.Nat.eq_dec i n). + destruct (Nat.eq_dec i n). + subst. erewrite index_pf_irrel. rewrite eqq; simpl; trivial. @@ -1066,7 +1066,7 @@ Section Vector. intros [i pf2]. apply v. exists i. - eapply NPeano.Nat.lt_le_trans; eassumption. + eapply Nat.lt_le_trans; eassumption. Defined. Lemma vfirstn0 {T} {n} (v:Vector T n) pf : vfirstn v 0 pf = vnil. @@ -1127,10 +1127,10 @@ Section Vector. sing (v (exist (fun n' : nat => (n' < S m)%nat) 0%nat pf1)) | S bound2 => fun (an' : A (S bound2)) (_ : (S (S bound2) <= S m)%nat) => - f (S bound2) (v (exist (fun n' : nat => (n' < S m)%nat) bound (Le.le_Sn_le (S bound) (S m) pf))) an' + f (S bound2) (v (exist (fun n' : nat => (n' < S m)%nat) bound (Nat.lt_le_incl (S bound) (S m) pf))) an' end (vector_fold_right1_bounded_dep f init sing v bound - (Le.le_Sn_le bound (S m) (Le.le_Sn_le (S bound) (S m) pf))) (Le.le_Sn_le (S bound) (S m) pf)) with (vector_fold_right1_bounded_dep f init sing v (S bound) pf2); try eapply IHbound. + (Nat.lt_le_incl bound (S m) (Nat.lt_le_incl (S bound) (S m) pf))) (Nat.lt_le_incl (S bound) (S m) pf)) with (vector_fold_right1_bounded_dep f init sing v (S bound) pf2); try eapply IHbound. clear. destruct bound; simpl. -- erewrite index_pf_irrel; eauto. @@ -1159,7 +1159,7 @@ Section Vector. forall {m:nat} (v:Vector B m), P m v (vector_fold_right1_dep f init sing v). Proof. intros. - rewrite <- (vfirstn_eq v (Le.le_refl m)) at 1. + rewrite <- (vfirstn_eq v (Nat.le_refl m)) at 1. apply vector_fold_right1_bounded_dep_gen_ind; trivial. Qed. @@ -1171,8 +1171,9 @@ Section Vector. Lemma vtake_skip_app_eq_pf n m (pf:(n<=m)%nat) : n + (m - n) = m. Proof. - rewrite NPeano.Nat.add_sub_assoc by trivial. - now rewrite Minus.minus_plus. + rewrite Nat.add_sub_assoc by trivial. + rewrite Nat.add_comm. + apply Nat.add_sub. Defined. Lemma vtake_skip_app_lt_pf {m n i} (pf:(n<=m)%nat) (p2f:i < m) : i < n + (m - n). diff --git a/coq/NeuralNetworks/derivlemmas.v b/rocq/utils/derivlemmas.v similarity index 100% rename from coq/NeuralNetworks/derivlemmas.v rename to rocq/utils/derivlemmas.v diff --git a/coq/utils/improper_integrals.v b/rocq/utils/improper_integrals.v similarity index 100% rename from coq/utils/improper_integrals.v rename to rocq/utils/improper_integrals.v diff --git a/coq/utils/quotient_space.v b/rocq/utils/quotient_space.v similarity index 100% rename from coq/utils/quotient_space.v rename to rocq/utils/quotient_space.v