diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 310fc2d4..a5e5fe79 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -10,12 +10,10 @@ jobs:
runs-on: ubuntu-latest # container actions require GNU/Linux
strategy:
matrix:
- coq_container:
-# - coqorg/coq:8.12.2
-# - coqorg/coq:8.16.1-ocaml-4.13.1-flambda
- - coqorg/coq:8.18.0-ocaml-4.13.1-flambda
+ rocq_container:
+ - rocq/rocq-prover:9.0.1-ocaml-4.14.2-flambda
container:
- image: ${{ matrix.coq_container }}
+ image: ${{ matrix.rocq_container }}
options: --user root
steps:
- uses: actions/checkout@v4
@@ -26,31 +24,8 @@ jobs:
- name: ls
run: ls -la .
- name: Install Opam dependencies
- run: su coq -c 'eval $(opam env) && opam install --deps-only --with-test --with-doc -y -j 2 ./Formal_ML.opam'
+ run: su rocq -c 'eval $(opam env) && opam install --deps-only --with-test --with-doc -y -j 2 ./Formal_ML.opam'
- name: Build using Make
- run: su coq -c 'eval $(opam env) && make -kj 2'
+ run: su rocq -c 'eval $(opam env) && make -kj 2'
- name: Build documentation
- run: su coq -c 'eval $(opam env) && make -kj 2 doc'
-
-# - uses: coq-community/docker-coq-action@v1
-# with:
-# opam_file: 'Formal_ML.opam'
-# coq_version: ${{ matrix.coq_version }}
-# ocaml_version: ${{ matrix.ocaml_version }}
-# # export: 'OPAMWITHTEST OPAMWITHDOC'
-# export: 'OPAMWITHDOC'
-# after_script: |
-# sudo cp -a $(opam config var Formal_ML:build)/documentation .
-# env:
-# OPAMWITHDOC: 'true'
-# OPAMWITHTEST: 'true'
- # - if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- # name: deploy documentation
- # uses: JamesIves/github-pages-deploy-action@3.7.1
- # with:
- # ACCESS_TOKEN: ${{ secrets.ACCESS_TOKEN }}
- # REPOSITORY_NAME: FormalML/FormalML.github.io # the target repository
- # TARGET_FOLDER: main/documentation # target directory
- # BRANCH: main # The branch the action should deploy to.
- # FOLDER: documentation # The folder the action should deploy.
- # CLEAN: true # Automatically remove deleted files from the deploy branch
+ run: su rocq -c 'eval $(opam env) && make -kj 2 doc'
diff --git a/.gitignore b/.gitignore
index 7a410b78..42bb264a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,9 +6,9 @@
*.aux
.coqdeps.d
.#*
-Makefile.coq
-Makefile.coq.conf
-.Makefile.coq.d
+Makefile.rocq
+Makefile.rocq.conf
+.Makefile.rocq.d
extracted
ocaml/_build
bin
diff --git a/Dockerfile b/Dockerfile
index 35a34409..a47e7e2f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,13 +1,13 @@
-ARG coq_image="coqorg/coq:8.12.2"
-FROM ${coq_image}
+ARG rocq_image="rocq/rocq-prover:9.0.1"
+FROM ${rocq_image}
MAINTAINER Avi Shinnar "shinnar@us.ibm.com"
# needs to be a subdirectory to avoid causing problems with
-# the /home/coq/.opam directory (and probably other stuff)
-WORKDIR /home/coq
+# the /home/rocq/.opam directory (and probably other stuff)
+WORKDIR /home/rocq
-COPY --chown=coq:coq Formal_ML.opam ./formal_ml/
+COPY --chown=rocq:rocq Formal_ML.opam ./formal_ml/
RUN ["/bin/bash", "--login", "-c", "set -x \
&& if [ -n \"${COMPILER_EDGE}\" ]; then opam switch ${COMPILER_EDGE} && eval $(opam env); fi \
@@ -16,13 +16,8 @@ RUN ["/bin/bash", "--login", "-c", "set -x \
&& opam clean -a -c -s --logs"]
-COPY --chown=coq:coq breast-cancer-wisconsin.data breast-cancer-wisconsin.names ./formal_ml/
-COPY --chown=coq:coq _CoqProject Makefile Makefile.coq_modules ./formal_ml/
-COPY --chown=coq:coq coq ./formal_ml/coq
-COPY --chown=coq:coq ocaml ./formal_ml/ocaml
+COPY --chown=rocq:rocq _RocqProject Makefile Makefile.rocq_modules ./formal_ml/
+COPY --chown=rocq:rocq rocq ./formal_ml/rocq
RUN ["/bin/bash", "--login", "-c", "set -x && cd formal_ml && \
make && make doc"]
-
-# CMD ["/bin/bash", "--login", "-c", "set -x && cd formal_ml && \
-# make test"]
\ No newline at end of file
diff --git a/Formal_ML.opam b/Formal_ML.opam
index 8d26190f..0c08b859 100644
--- a/Formal_ML.opam
+++ b/Formal_ML.opam
@@ -9,25 +9,15 @@ homepage: "https://github.com/ibm/formalml"
bug-reports: "https://github.com/ibm/formalml/issues"
depends: [
"ocaml" {>= "4.07.0"}
- "coq" {>= "8.12.1"}
- "coq-mathcomp-ssreflect"
- "coq-mathcomp-algebra"
- "coq-mathcomp-algebra-tactics"
- "coq-mathcomp-real-closed"
- "coq-mathcomp-analysis" {< "1.0.0"}
- "coq-coquelicot" {= "3.3.1" }
- "coq-flocq" {>= "4.0.0" }
- "coq-interval" {>= "4.8.0"}
+ "rocq-core" {>= "9.0.0"}
+ "rocq-stdlib"
+ "rocq-mathcomp-ssreflect"
+ "coq-coquelicot"
"coq-ext-lib" {<= "1.0.0"}
- "ocamlbuild"
- "base64"
- "menhir"
- "csv"
"coq-coq2html" {with-doc}
]
build: [[make]
[make "doc"] {with-doc}
- [make "test"] {with-test}
]
install: [make]
dev-repo: "git+https://github.com/IBM/FormalML.git"
diff --git a/Makefile b/Makefile
index 0f60abbf..a11ded7d 100644
--- a/Makefile
+++ b/Makefile
@@ -1,38 +1,29 @@
-# Contains the list of all the Coq modules
-include Makefile.coq_modules
+# Contains the list of all the Rocq modules
+include Makefile.rocq_modules
-COQ_FILES = $(addprefix coq/,$(MODULES:%=%.v))
+ROCQ_FILES = $(addprefix rocq/,$(MODULES:%=%.v))
-all: coq # ocaml
+all: rocq
-coq: Makefile.coq
- @$(MAKE) -f Makefile.coq
+rocq: Makefile.rocq
+ @$(MAKE) -f Makefile.rocq
-Makefile.coq: Makefile Makefile.coq_modules $(COQ_FILES)
- @coq_makefile -f _CoqProject $(COQ_FILES) -o Makefile.coq
+Makefile.rocq: Makefile Makefile.rocq_modules $(ROCQ_FILES)
+ @rocq makefile -f _RocqProject $(ROCQ_FILES) -o Makefile.rocq
-ocaml: coq
- @$(MAKE) -C ocaml native
+clean-rocq:
+ - @$(MAKE) -f Makefile.rocq clean
-clean-coq:
- - @$(MAKE) -f Makefile.coq clean
-clean-ocaml:
- @$(MAKE) -C ocaml clean
-
-
-COQ_FILES_FOR_DOC = $(MODULES:%=%.v)
+ROCQ_FILES_FOR_DOC = $(MODULES:%=%.v)
GLOB_FILES_FOR_DOC = $(MODULES:%=%.glob)
-doc: coq
+doc: rocq
mkdir -p documentation/html
rm -f documentation/html/*.html
- cd coq && coq2html -d ../documentation/html -base FormalML -external http://coquelicot.saclay.inria.fr/html/ Coquelicot $(COQ_FILES_FOR_DOC) $(GLOB_FILES_FOR_DOC)
-
-test: coq ocaml
- ./bin/nnopt
+ cd rocq && coq2html -d ../documentation/html -base FormalML -external http://coquelicot.saclay.inria.fr/html/ Coquelicot $(ROCQ_FILES_FOR_DOC) $(GLOB_FILES_FOR_DOC)
-clean: clean-coq clean-ocaml
+clean: clean-rocq
rm -rf documentation/html
-.PHONY: all ocaml clean clean-coq coq test doc
+.PHONY: all clean clean-rocq rocq doc
diff --git a/Makefile.coq_modules b/Makefile.rocq_modules
similarity index 84%
rename from Makefile.coq_modules
rename to Makefile.rocq_modules
index 981360eb..d5891582 100644
--- a/Makefile.coq_modules
+++ b/Makefile.rocq_modules
@@ -30,7 +30,6 @@ UTILS = BasicUtils \
ClassicUtils \
CoquelicotAdd \
ELim_Seq \
- ExtrFloatishIEEE \
improper_integrals \
Isomorphism \
ListAdd \
@@ -46,24 +45,11 @@ UTILS = BasicUtils \
StreamAdd \
StreamLimits \
Sums \
- nvector \
Vector \
PushNeg \
DVector \
Utils \
- Floatish/FloatishDef \
- Floatish/FloatishOps \
- Floatish/FloatishRealOps \
- Floatish/FloatishInterval \
- Floatish/FloatishIEEE \
- Floatish/FloatishReal \
- Floatish
-
-NEURAL_NETWORKS = AxiomaticNormedRealVectorSpace \
- DefinedFunctions \
- derivlemmas \
- Gen_NN
-
+ derivlemmas
CERTRL = pmf_monad \
qvalues \
@@ -122,20 +108,10 @@ QLEARN = \
Tsitsiklis \
jaakkola_vector
-FHE = \
- nth_root \
- encode \
- encrypt \
- zp_prim_root \
- arith
-
MODULES = $(addprefix lib_utils/,$(QCERT_LIB_UTILS)) \
$(addprefix CertRL/LM/,$(ELFIC_UTILS)) \
$(addprefix utils/,$(UTILS)) \
- $(addprefix NeuralNetworks/,$(NEURAL_NETWORKS)) \
$(addprefix ProbTheory/,$(PROB_THEORY)) \
$(addprefix CertRL/,$(CERTRL)) \
$(addprefix QLearn/,$(QLEARN)) \
- $(addprefix FHE/,$(FHE)) \
- API
-
+
diff --git a/README.md b/README.md
index ccbb75e9..04a0e7e7 100644
--- a/README.md
+++ b/README.md
@@ -8,15 +8,15 @@
## Getting Started
- To compile the Coq code in this repository,
+ To compile the [Rocq](https://rocq-prover.org/) (previously known as Coq) code in this repository, [Install Rocq](https://rocq-prover.org/install). For example:
- first install opam [opam (ocaml package manager)](https://opam.ocaml.org/).
- - Add support for coq ocaml repositories: `opam repo add coq-released --set-default https://coq.inria.fr/opam/released`.
- - If you want to create a local environment (switch), you can run `opam switch create nnsopt 4.07.0`.
- - Next, run `opam install . --deps-only`. This should install all the dependencies needed, including Coq.
+ - Add support for rocq ocaml repositories: `opam repo add rocq-released https://rocq-prover.org/opam/released`
+ - If you want to create a local environment (switch), you can run `opam switch create formalml 4.14.2`.
+ - Next, run `opam install . --deps-only`. This should install all the dependencies needed, including Rocq.
- Once the prerequisites are installed, run `make` to compile it.
- Alternatively, the included Docker file can be built using Docker to compile the coq code in a suitable environment.
- `docker build --build-arg=coq_image="coqorg/coq:8.8.2" --pull -t nn_sopt .`
+ Alternatively, the included Docker file can be built using Docker to compile the rocq code in a suitable environment.
+ `docker build --pull -t formalml .`
## License
This repository is distributed under the terms of the Apache 2.0 License, see LICENSE.txt.
diff --git a/_CoqProject b/_CoqProject
deleted file mode 100644
index 101bc735..00000000
--- a/_CoqProject
+++ /dev/null
@@ -1 +0,0 @@
--R coq FormalML -arg -set -arg "Warnings=+default,-ambiguous-path,-coercions,-hiding-delimiting-key,-overwriting-delimiting-key,-redundant-canonical-projection,-typechecker,-ssr-search-moved,-deprecated,-notation-overridden"
\ No newline at end of file
diff --git a/_RocqProject b/_RocqProject
new file mode 100644
index 00000000..cf77dfa1
--- /dev/null
+++ b/_RocqProject
@@ -0,0 +1 @@
+-R rocq FormalML -arg -set -arg "Warnings=+default,-ambiguous-path,-coercions,-hiding-delimiting-key,-overwriting-delimiting-key,-redundant-canonical-projection,-typechecker,-ssr-search-moved,-deprecated,-notation-overridden"
\ No newline at end of file
diff --git a/breast-cancer-wisconsin.data b/breast-cancer-wisconsin.data
deleted file mode 100644
index 9e329d00..00000000
--- a/breast-cancer-wisconsin.data
+++ /dev/null
@@ -1,683 +0,0 @@
-1000025,5,1,1,1,2,1,3,1,1,2
-1002945,5,4,4,5,7,10,3,2,1,2
-1015425,3,1,1,1,2,2,3,1,1,2
-1016277,6,8,8,1,3,4,3,7,1,2
-1017023,4,1,1,3,2,1,3,1,1,2
-1017122,8,10,10,8,7,10,9,7,1,4
-1018099,1,1,1,1,2,10,3,1,1,2
-1018561,2,1,2,1,2,1,3,1,1,2
-1033078,2,1,1,1,2,1,1,1,5,2
-1033078,4,2,1,1,2,1,2,1,1,2
-1035283,1,1,1,1,1,1,3,1,1,2
-1036172,2,1,1,1,2,1,2,1,1,2
-1041801,5,3,3,3,2,3,4,4,1,4
-1043999,1,1,1,1,2,3,3,1,1,2
-1044572,8,7,5,10,7,9,5,5,4,4
-1047630,7,4,6,4,6,1,4,3,1,4
-1048672,4,1,1,1,2,1,2,1,1,2
-1049815,4,1,1,1,2,1,3,1,1,2
-1050670,10,7,7,6,4,10,4,1,2,4
-1050718,6,1,1,1,2,1,3,1,1,2
-1054590,7,3,2,10,5,10,5,4,4,4
-1054593,10,5,5,3,6,7,7,10,1,4
-1056784,3,1,1,1,2,1,2,1,1,2
-1059552,1,1,1,1,2,1,3,1,1,2
-1065726,5,2,3,4,2,7,3,6,1,4
-1066373,3,2,1,1,1,1,2,1,1,2
-1066979,5,1,1,1,2,1,2,1,1,2
-1067444,2,1,1,1,2,1,2,1,1,2
-1070935,1,1,3,1,2,1,1,1,1,2
-1070935,3,1,1,1,1,1,2,1,1,2
-1071760,2,1,1,1,2,1,3,1,1,2
-1072179,10,7,7,3,8,5,7,4,3,4
-1074610,2,1,1,2,2,1,3,1,1,2
-1075123,3,1,2,1,2,1,2,1,1,2
-1079304,2,1,1,1,2,1,2,1,1,2
-1080185,10,10,10,8,6,1,8,9,1,4
-1081791,6,2,1,1,1,1,7,1,1,2
-1084584,5,4,4,9,2,10,5,6,1,4
-1091262,2,5,3,3,6,7,7,5,1,4
-1099510,10,4,3,1,3,3,6,5,2,4
-1100524,6,10,10,2,8,10,7,3,3,4
-1102573,5,6,5,6,10,1,3,1,1,4
-1103608,10,10,10,4,8,1,8,10,1,4
-1103722,1,1,1,1,2,1,2,1,2,2
-1105257,3,7,7,4,4,9,4,8,1,4
-1105524,1,1,1,1,2,1,2,1,1,2
-1106095,4,1,1,3,2,1,3,1,1,2
-1106829,7,8,7,2,4,8,3,8,2,4
-1108370,9,5,8,1,2,3,2,1,5,4
-1108449,5,3,3,4,2,4,3,4,1,4
-1110102,10,3,6,2,3,5,4,10,2,4
-1110503,5,5,5,8,10,8,7,3,7,4
-1110524,10,5,5,6,8,8,7,1,1,4
-1111249,10,6,6,3,4,5,3,6,1,4
-1112209,8,10,10,1,3,6,3,9,1,4
-1113038,8,2,4,1,5,1,5,4,4,4
-1113483,5,2,3,1,6,10,5,1,1,4
-1113906,9,5,5,2,2,2,5,1,1,4
-1115282,5,3,5,5,3,3,4,10,1,4
-1115293,1,1,1,1,2,2,2,1,1,2
-1116116,9,10,10,1,10,8,3,3,1,4
-1116132,6,3,4,1,5,2,3,9,1,4
-1116192,1,1,1,1,2,1,2,1,1,2
-1116998,10,4,2,1,3,2,4,3,10,4
-1117152,4,1,1,1,2,1,3,1,1,2
-1118039,5,3,4,1,8,10,4,9,1,4
-1120559,8,3,8,3,4,9,8,9,8,4
-1121732,1,1,1,1,2,1,3,2,1,2
-1121919,5,1,3,1,2,1,2,1,1,2
-1123061,6,10,2,8,10,2,7,8,10,4
-1124651,1,3,3,2,2,1,7,2,1,2
-1125035,9,4,5,10,6,10,4,8,1,4
-1126417,10,6,4,1,3,4,3,2,3,4
-1131294,1,1,2,1,2,2,4,2,1,2
-1132347,1,1,4,1,2,1,2,1,1,2
-1133041,5,3,1,2,2,1,2,1,1,2
-1133136,3,1,1,1,2,3,3,1,1,2
-1136142,2,1,1,1,3,1,2,1,1,2
-1137156,2,2,2,1,1,1,7,1,1,2
-1143978,4,1,1,2,2,1,2,1,1,2
-1143978,5,2,1,1,2,1,3,1,1,2
-1147044,3,1,1,1,2,2,7,1,1,2
-1147699,3,5,7,8,8,9,7,10,7,4
-1147748,5,10,6,1,10,4,4,10,10,4
-1148278,3,3,6,4,5,8,4,4,1,4
-1148873,3,6,6,6,5,10,6,8,3,4
-1152331,4,1,1,1,2,1,3,1,1,2
-1155546,2,1,1,2,3,1,2,1,1,2
-1156272,1,1,1,1,2,1,3,1,1,2
-1156948,3,1,1,2,2,1,1,1,1,2
-1157734,4,1,1,1,2,1,3,1,1,2
-1158247,1,1,1,1,2,1,2,1,1,2
-1160476,2,1,1,1,2,1,3,1,1,2
-1164066,1,1,1,1,2,1,3,1,1,2
-1165297,2,1,1,2,2,1,1,1,1,2
-1165790,5,1,1,1,2,1,3,1,1,2
-1165926,9,6,9,2,10,6,2,9,10,4
-1166630,7,5,6,10,5,10,7,9,4,4
-1166654,10,3,5,1,10,5,3,10,2,4
-1167439,2,3,4,4,2,5,2,5,1,4
-1167471,4,1,2,1,2,1,3,1,1,2
-1168359,8,2,3,1,6,3,7,1,1,4
-1168736,10,10,10,10,10,1,8,8,8,4
-1169049,7,3,4,4,3,3,3,2,7,4
-1170419,10,10,10,8,2,10,4,1,1,4
-1170420,1,6,8,10,8,10,5,7,1,4
-1171710,1,1,1,1,2,1,2,3,1,2
-1171710,6,5,4,4,3,9,7,8,3,4
-1171795,1,3,1,2,2,2,5,3,2,2
-1171845,8,6,4,3,5,9,3,1,1,4
-1172152,10,3,3,10,2,10,7,3,3,4
-1173216,10,10,10,3,10,8,8,1,1,4
-1173235,3,3,2,1,2,3,3,1,1,2
-1173347,1,1,1,1,2,5,1,1,1,2
-1173347,8,3,3,1,2,2,3,2,1,2
-1173509,4,5,5,10,4,10,7,5,8,4
-1173514,1,1,1,1,4,3,1,1,1,2
-1173681,3,2,1,1,2,2,3,1,1,2
-1174057,1,1,2,2,2,1,3,1,1,2
-1174057,4,2,1,1,2,2,3,1,1,2
-1174131,10,10,10,2,10,10,5,3,3,4
-1174428,5,3,5,1,8,10,5,3,1,4
-1175937,5,4,6,7,9,7,8,10,1,4
-1176406,1,1,1,1,2,1,2,1,1,2
-1176881,7,5,3,7,4,10,7,5,5,4
-1177027,3,1,1,1,2,1,3,1,1,2
-1177399,8,3,5,4,5,10,1,6,2,4
-1177512,1,1,1,1,10,1,1,1,1,2
-1178580,5,1,3,1,2,1,2,1,1,2
-1179818,2,1,1,1,2,1,3,1,1,2
-1180194,5,10,8,10,8,10,3,6,3,4
-1180523,3,1,1,1,2,1,2,2,1,2
-1180831,3,1,1,1,3,1,2,1,1,2
-1181356,5,1,1,1,2,2,3,3,1,2
-1182404,4,1,1,1,2,1,2,1,1,2
-1182410,3,1,1,1,2,1,1,1,1,2
-1183240,4,1,2,1,2,1,2,1,1,2
-1183516,3,1,1,1,2,1,1,1,1,2
-1183911,2,1,1,1,2,1,1,1,1,2
-1183983,9,5,5,4,4,5,4,3,3,4
-1184184,1,1,1,1,2,5,1,1,1,2
-1184241,2,1,1,1,2,1,2,1,1,2
-1185609,3,4,5,2,6,8,4,1,1,4
-1185610,1,1,1,1,3,2,2,1,1,2
-1187457,3,1,1,3,8,1,5,8,1,2
-1187805,8,8,7,4,10,10,7,8,7,4
-1188472,1,1,1,1,1,1,3,1,1,2
-1189266,7,2,4,1,6,10,5,4,3,4
-1189286,10,10,8,6,4,5,8,10,1,4
-1190394,4,1,1,1,2,3,1,1,1,2
-1190485,1,1,1,1,2,1,1,1,1,2
-1192325,5,5,5,6,3,10,3,1,1,4
-1193091,1,2,2,1,2,1,2,1,1,2
-1193210,2,1,1,1,2,1,3,1,1,2
-1196295,9,9,10,3,6,10,7,10,6,4
-1196915,10,7,7,4,5,10,5,7,2,4
-1197080,4,1,1,1,2,1,3,2,1,2
-1197270,3,1,1,1,2,1,3,1,1,2
-1197440,1,1,1,2,1,3,1,1,7,2
-1197979,4,1,1,1,2,2,3,2,1,2
-1197993,5,6,7,8,8,10,3,10,3,4
-1198128,10,8,10,10,6,1,3,1,10,4
-1198641,3,1,1,1,2,1,3,1,1,2
-1199219,1,1,1,2,1,1,1,1,1,2
-1199731,3,1,1,1,2,1,1,1,1,2
-1199983,1,1,1,1,2,1,3,1,1,2
-1200772,1,1,1,1,2,1,2,1,1,2
-1200847,6,10,10,10,8,10,10,10,7,4
-1200892,8,6,5,4,3,10,6,1,1,4
-1200952,5,8,7,7,10,10,5,7,1,4
-1201834,2,1,1,1,2,1,3,1,1,2
-1201936,5,10,10,3,8,1,5,10,3,4
-1202125,4,1,1,1,2,1,3,1,1,2
-1202812,5,3,3,3,6,10,3,1,1,4
-1203096,1,1,1,1,1,1,3,1,1,2
-1204242,1,1,1,1,2,1,1,1,1,2
-1204898,6,1,1,1,2,1,3,1,1,2
-1205138,5,8,8,8,5,10,7,8,1,4
-1205579,8,7,6,4,4,10,5,1,1,4
-1206089,2,1,1,1,1,1,3,1,1,2
-1206695,1,5,8,6,5,8,7,10,1,4
-1206841,10,5,6,10,6,10,7,7,10,4
-1207986,5,8,4,10,5,8,9,10,1,4
-1208301,1,2,3,1,2,1,3,1,1,2
-1210963,10,10,10,8,6,8,7,10,1,4
-1211202,7,5,10,10,10,10,4,10,3,4
-1212232,5,1,1,1,2,1,2,1,1,2
-1212251,1,1,1,1,2,1,3,1,1,2
-1212422,3,1,1,1,2,1,3,1,1,2
-1212422,4,1,1,1,2,1,3,1,1,2
-1213375,8,4,4,5,4,7,7,8,2,2
-1213383,5,1,1,4,2,1,3,1,1,2
-1214092,1,1,1,1,2,1,1,1,1,2
-1214556,3,1,1,1,2,1,2,1,1,2
-1214966,9,7,7,5,5,10,7,8,3,4
-1216694,10,8,8,4,10,10,8,1,1,4
-1216947,1,1,1,1,2,1,3,1,1,2
-1217051,5,1,1,1,2,1,3,1,1,2
-1217264,1,1,1,1,2,1,3,1,1,2
-1218105,5,10,10,9,6,10,7,10,5,4
-1218741,10,10,9,3,7,5,3,5,1,4
-1218860,1,1,1,1,1,1,3,1,1,2
-1218860,1,1,1,1,1,1,3,1,1,2
-1219406,5,1,1,1,1,1,3,1,1,2
-1219525,8,10,10,10,5,10,8,10,6,4
-1219859,8,10,8,8,4,8,7,7,1,4
-1220330,1,1,1,1,2,1,3,1,1,2
-1221863,10,10,10,10,7,10,7,10,4,4
-1222047,10,10,10,10,3,10,10,6,1,4
-1222936,8,7,8,7,5,5,5,10,2,4
-1223282,1,1,1,1,2,1,2,1,1,2
-1223426,1,1,1,1,2,1,3,1,1,2
-1223793,6,10,7,7,6,4,8,10,2,4
-1223967,6,1,3,1,2,1,3,1,1,2
-1224329,1,1,1,2,2,1,3,1,1,2
-1225799,10,6,4,3,10,10,9,10,1,4
-1226012,4,1,1,3,1,5,2,1,1,4
-1226612,7,5,6,3,3,8,7,4,1,4
-1227210,10,5,5,6,3,10,7,9,2,4
-1227244,1,1,1,1,2,1,2,1,1,2
-1227481,10,5,7,4,4,10,8,9,1,4
-1228152,8,9,9,5,3,5,7,7,1,4
-1228311,1,1,1,1,1,1,3,1,1,2
-1230175,10,10,10,3,10,10,9,10,1,4
-1230688,7,4,7,4,3,7,7,6,1,4
-1231387,6,8,7,5,6,8,8,9,2,4
-1231706,8,4,6,3,3,1,4,3,1,2
-1232225,10,4,5,5,5,10,4,1,1,4
-1236043,3,3,2,1,3,1,3,6,1,2
-1241559,10,8,8,2,8,10,4,8,10,4
-1241679,9,8,8,5,6,2,4,10,4,4
-1242364,8,10,10,8,6,9,3,10,10,4
-1243256,10,4,3,2,3,10,5,3,2,4
-1270479,5,1,3,3,2,2,2,3,1,2
-1276091,3,1,1,3,1,1,3,1,1,2
-1277018,2,1,1,1,2,1,3,1,1,2
-128059,1,1,1,1,2,5,5,1,1,2
-1285531,1,1,1,1,2,1,3,1,1,2
-1287775,5,1,1,2,2,2,3,1,1,2
-144888,8,10,10,8,5,10,7,8,1,4
-145447,8,4,4,1,2,9,3,3,1,4
-167528,4,1,1,1,2,1,3,6,1,2
-183913,1,2,2,1,2,1,1,1,1,2
-191250,10,4,4,10,2,10,5,3,3,4
-1017023,6,3,3,5,3,10,3,5,3,2
-1100524,6,10,10,2,8,10,7,3,3,4
-1116116,9,10,10,1,10,8,3,3,1,4
-1168736,5,6,6,2,4,10,3,6,1,4
-1182404,3,1,1,1,2,1,1,1,1,2
-1182404,3,1,1,1,2,1,2,1,1,2
-1198641,3,1,1,1,2,1,3,1,1,2
-242970,5,7,7,1,5,8,3,4,1,2
-255644,10,5,8,10,3,10,5,1,3,4
-263538,5,10,10,6,10,10,10,6,5,4
-274137,8,8,9,4,5,10,7,8,1,4
-303213,10,4,4,10,6,10,5,5,1,4
-314428,7,9,4,10,10,3,5,3,3,4
-1182404,5,1,4,1,2,1,3,2,1,2
-1198641,10,10,6,3,3,10,4,3,2,4
-320675,3,3,5,2,3,10,7,1,1,4
-324427,10,8,8,2,3,4,8,7,8,4
-385103,1,1,1,1,2,1,3,1,1,2
-390840,8,4,7,1,3,10,3,9,2,4
-411453,5,1,1,1,2,1,3,1,1,2
-320675,3,3,5,2,3,10,7,1,1,4
-428903,7,2,4,1,3,4,3,3,1,4
-431495,3,1,1,1,2,1,3,2,1,2
-434518,3,1,1,1,2,1,2,1,1,2
-452264,1,1,1,1,2,1,2,1,1,2
-456282,1,1,1,1,2,1,3,1,1,2
-476903,10,5,7,3,3,7,3,3,8,4
-486283,3,1,1,1,2,1,3,1,1,2
-486662,2,1,1,2,2,1,3,1,1,2
-488173,1,4,3,10,4,10,5,6,1,4
-492268,10,4,6,1,2,10,5,3,1,4
-508234,7,4,5,10,2,10,3,8,2,4
-527363,8,10,10,10,8,10,10,7,3,4
-529329,10,10,10,10,10,10,4,10,10,4
-535331,3,1,1,1,3,1,2,1,1,2
-543558,6,1,3,1,4,5,5,10,1,4
-555977,5,6,6,8,6,10,4,10,4,4
-560680,1,1,1,1,2,1,1,1,1,2
-561477,1,1,1,1,2,1,3,1,1,2
-601265,10,4,4,6,2,10,2,3,1,4
-606722,5,5,7,8,6,10,7,4,1,4
-616240,5,3,4,3,4,5,4,7,1,2
-625201,8,2,1,1,5,1,1,1,1,2
-63375,9,1,2,6,4,10,7,7,2,4
-635844,8,4,10,5,4,4,7,10,1,4
-636130,1,1,1,1,2,1,3,1,1,2
-640744,10,10,10,7,9,10,7,10,10,4
-646904,1,1,1,1,2,1,3,1,1,2
-653777,8,3,4,9,3,10,3,3,1,4
-659642,10,8,4,4,4,10,3,10,4,4
-666090,1,1,1,1,2,1,3,1,1,2
-666942,1,1,1,1,2,1,3,1,1,2
-667204,7,8,7,6,4,3,8,8,4,4
-673637,3,1,1,1,2,5,5,1,1,2
-684955,2,1,1,1,3,1,2,1,1,2
-688033,1,1,1,1,2,1,1,1,1,2
-691628,8,6,4,10,10,1,3,5,1,4
-693702,1,1,1,1,2,1,1,1,1,2
-704097,1,1,1,1,1,1,2,1,1,2
-706426,5,5,5,2,5,10,4,3,1,4
-709287,6,8,7,8,6,8,8,9,1,4
-718641,1,1,1,1,5,1,3,1,1,2
-721482,4,4,4,4,6,5,7,3,1,2
-730881,7,6,3,2,5,10,7,4,6,4
-733639,3,1,1,1,2,1,3,1,1,2
-733823,5,4,6,10,2,10,4,1,1,4
-740492,1,1,1,1,2,1,3,1,1,2
-743348,3,2,2,1,2,1,2,3,1,2
-752904,10,1,1,1,2,10,5,4,1,4
-756136,1,1,1,1,2,1,2,1,1,2
-760001,8,10,3,2,6,4,3,10,1,4
-760239,10,4,6,4,5,10,7,1,1,4
-76389,10,4,7,2,2,8,6,1,1,4
-764974,5,1,1,1,2,1,3,1,2,2
-770066,5,2,2,2,2,1,2,2,1,2
-785208,5,4,6,6,4,10,4,3,1,4
-785615,8,6,7,3,3,10,3,4,2,4
-792744,1,1,1,1,2,1,1,1,1,2
-797327,6,5,5,8,4,10,3,4,1,4
-798429,1,1,1,1,2,1,3,1,1,2
-704097,1,1,1,1,1,1,2,1,1,2
-806423,8,5,5,5,2,10,4,3,1,4
-809912,10,3,3,1,2,10,7,6,1,4
-810104,1,1,1,1,2,1,3,1,1,2
-814265,2,1,1,1,2,1,1,1,1,2
-814911,1,1,1,1,2,1,1,1,1,2
-822829,7,6,4,8,10,10,9,5,3,4
-826923,1,1,1,1,2,1,1,1,1,2
-830690,5,2,2,2,3,1,1,3,1,2
-831268,1,1,1,1,1,1,1,3,1,2
-832226,3,4,4,10,5,1,3,3,1,4
-832567,4,2,3,5,3,8,7,6,1,4
-836433,5,1,1,3,2,1,1,1,1,2
-837082,2,1,1,1,2,1,3,1,1,2
-846832,3,4,5,3,7,3,4,6,1,2
-850831,2,7,10,10,7,10,4,9,4,4
-855524,1,1,1,1,2,1,2,1,1,2
-857774,4,1,1,1,3,1,2,2,1,2
-859164,5,3,3,1,3,3,3,3,3,4
-859350,8,10,10,7,10,10,7,3,8,4
-866325,8,10,5,3,8,4,4,10,3,4
-873549,10,3,5,4,3,7,3,5,3,4
-877291,6,10,10,10,10,10,8,10,10,4
-877943,3,10,3,10,6,10,5,1,4,4
-888169,3,2,2,1,4,3,2,1,1,2
-888523,4,4,4,2,2,3,2,1,1,2
-896404,2,1,1,1,2,1,3,1,1,2
-897172,2,1,1,1,2,1,2,1,1,2
-95719,6,10,10,10,8,10,7,10,7,4
-160296,5,8,8,10,5,10,8,10,3,4
-342245,1,1,3,1,2,1,1,1,1,2
-428598,1,1,3,1,1,1,2,1,1,2
-492561,4,3,2,1,3,1,2,1,1,2
-493452,1,1,3,1,2,1,1,1,1,2
-493452,4,1,2,1,2,1,2,1,1,2
-521441,5,1,1,2,2,1,2,1,1,2
-560680,3,1,2,1,2,1,2,1,1,2
-636437,1,1,1,1,2,1,1,1,1,2
-640712,1,1,1,1,2,1,2,1,1,2
-654244,1,1,1,1,1,1,2,1,1,2
-657753,3,1,1,4,3,1,2,2,1,2
-685977,5,3,4,1,4,1,3,1,1,2
-805448,1,1,1,1,2,1,1,1,1,2
-846423,10,6,3,6,4,10,7,8,4,4
-1002504,3,2,2,2,2,1,3,2,1,2
-1022257,2,1,1,1,2,1,1,1,1,2
-1026122,2,1,1,1,2,1,1,1,1,2
-1071084,3,3,2,2,3,1,1,2,3,2
-1080233,7,6,6,3,2,10,7,1,1,4
-1114570,5,3,3,2,3,1,3,1,1,2
-1114570,2,1,1,1,2,1,2,2,1,2
-1116715,5,1,1,1,3,2,2,2,1,2
-1131411,1,1,1,2,2,1,2,1,1,2
-1151734,10,8,7,4,3,10,7,9,1,4
-1156017,3,1,1,1,2,1,2,1,1,2
-1158247,1,1,1,1,1,1,1,1,1,2
-1158405,1,2,3,1,2,1,2,1,1,2
-1168278,3,1,1,1,2,1,2,1,1,2
-1176187,3,1,1,1,2,1,3,1,1,2
-1196263,4,1,1,1,2,1,1,1,1,2
-1196475,3,2,1,1,2,1,2,2,1,2
-1206314,1,2,3,1,2,1,1,1,1,2
-1211265,3,10,8,7,6,9,9,3,8,4
-1213784,3,1,1,1,2,1,1,1,1,2
-1223003,5,3,3,1,2,1,2,1,1,2
-1223306,3,1,1,1,2,4,1,1,1,2
-1223543,1,2,1,3,2,1,1,2,1,2
-1229929,1,1,1,1,2,1,2,1,1,2
-1231853,4,2,2,1,2,1,2,1,1,2
-1234554,1,1,1,1,2,1,2,1,1,2
-1236837,2,3,2,2,2,2,3,1,1,2
-1237674,3,1,2,1,2,1,2,1,1,2
-1238021,1,1,1,1,2,1,2,1,1,2
-1238633,10,10,10,6,8,4,8,5,1,4
-1238915,5,1,2,1,2,1,3,1,1,2
-1238948,8,5,6,2,3,10,6,6,1,4
-1239232,3,3,2,6,3,3,3,5,1,2
-1239347,8,7,8,5,10,10,7,2,1,4
-1239967,1,1,1,1,2,1,2,1,1,2
-1240337,5,2,2,2,2,2,3,2,2,2
-1253505,2,3,1,1,5,1,1,1,1,2
-1255384,3,2,2,3,2,3,3,1,1,2
-1257200,10,10,10,7,10,10,8,2,1,4
-1257648,4,3,3,1,2,1,3,3,1,2
-1257815,5,1,3,1,2,1,2,1,1,2
-1257938,3,1,1,1,2,1,1,1,1,2
-1258549,9,10,10,10,10,10,10,10,1,4
-1258556,5,3,6,1,2,1,1,1,1,2
-1266154,8,7,8,2,4,2,5,10,1,4
-1272039,1,1,1,1,2,1,2,1,1,2
-1276091,2,1,1,1,2,1,2,1,1,2
-1276091,1,3,1,1,2,1,2,2,1,2
-1276091,5,1,1,3,4,1,3,2,1,2
-1277629,5,1,1,1,2,1,2,2,1,2
-1293439,3,2,2,3,2,1,1,1,1,2
-1293439,6,9,7,5,5,8,4,2,1,2
-1294562,10,8,10,1,3,10,5,1,1,4
-1295186,10,10,10,1,6,1,2,8,1,4
-527337,4,1,1,1,2,1,1,1,1,2
-558538,4,1,3,3,2,1,1,1,1,2
-566509,5,1,1,1,2,1,1,1,1,2
-608157,10,4,3,10,4,10,10,1,1,4
-677910,5,2,2,4,2,4,1,1,1,2
-734111,1,1,1,3,2,3,1,1,1,2
-734111,1,1,1,1,2,2,1,1,1,2
-780555,5,1,1,6,3,1,2,1,1,2
-827627,2,1,1,1,2,1,1,1,1,2
-1049837,1,1,1,1,2,1,1,1,1,2
-1058849,5,1,1,1,2,1,1,1,1,2
-1182404,1,1,1,1,1,1,1,1,1,2
-1193544,5,7,9,8,6,10,8,10,1,4
-1201870,4,1,1,3,1,1,2,1,1,2
-1202253,5,1,1,1,2,1,1,1,1,2
-1227081,3,1,1,3,2,1,1,1,1,2
-1230994,4,5,5,8,6,10,10,7,1,4
-1238410,2,3,1,1,3,1,1,1,1,2
-1246562,10,2,2,1,2,6,1,1,2,4
-1257470,10,6,5,8,5,10,8,6,1,4
-1259008,8,8,9,6,6,3,10,10,1,4
-1266124,5,1,2,1,2,1,1,1,1,2
-1267898,5,1,3,1,2,1,1,1,1,2
-1268313,5,1,1,3,2,1,1,1,1,2
-1268804,3,1,1,1,2,5,1,1,1,2
-1276091,6,1,1,3,2,1,1,1,1,2
-1280258,4,1,1,1,2,1,1,2,1,2
-1293966,4,1,1,1,2,1,1,1,1,2
-1296572,10,9,8,7,6,4,7,10,3,4
-1298416,10,6,6,2,4,10,9,7,1,4
-1299596,6,6,6,5,4,10,7,6,2,4
-1105524,4,1,1,1,2,1,1,1,1,2
-1181685,1,1,2,1,2,1,2,1,1,2
-1211594,3,1,1,1,1,1,2,1,1,2
-1238777,6,1,1,3,2,1,1,1,1,2
-1257608,6,1,1,1,1,1,1,1,1,2
-1269574,4,1,1,1,2,1,1,1,1,2
-1277145,5,1,1,1,2,1,1,1,1,2
-1287282,3,1,1,1,2,1,1,1,1,2
-1296025,4,1,2,1,2,1,1,1,1,2
-1296263,4,1,1,1,2,1,1,1,1,2
-1296593,5,2,1,1,2,1,1,1,1,2
-1299161,4,8,7,10,4,10,7,5,1,4
-1301945,5,1,1,1,1,1,1,1,1,2
-1302428,5,3,2,4,2,1,1,1,1,2
-1318169,9,10,10,10,10,5,10,10,10,4
-474162,8,7,8,5,5,10,9,10,1,4
-787451,5,1,2,1,2,1,1,1,1,2
-1002025,1,1,1,3,1,3,1,1,1,2
-1070522,3,1,1,1,1,1,2,1,1,2
-1073960,10,10,10,10,6,10,8,1,5,4
-1076352,3,6,4,10,3,3,3,4,1,4
-1084139,6,3,2,1,3,4,4,1,1,4
-1115293,1,1,1,1,2,1,1,1,1,2
-1119189,5,8,9,4,3,10,7,1,1,4
-1133991,4,1,1,1,1,1,2,1,1,2
-1142706,5,10,10,10,6,10,6,5,2,4
-1155967,5,1,2,10,4,5,2,1,1,2
-1170945,3,1,1,1,1,1,2,1,1,2
-1181567,1,1,1,1,1,1,1,1,1,2
-1182404,4,2,1,1,2,1,1,1,1,2
-1204558,4,1,1,1,2,1,2,1,1,2
-1217952,4,1,1,1,2,1,2,1,1,2
-1224565,6,1,1,1,2,1,3,1,1,2
-1238186,4,1,1,1,2,1,2,1,1,2
-1253917,4,1,1,2,2,1,2,1,1,2
-1265899,4,1,1,1,2,1,3,1,1,2
-1268766,1,1,1,1,2,1,1,1,1,2
-1277268,3,3,1,1,2,1,1,1,1,2
-1286943,8,10,10,10,7,5,4,8,7,4
-1295508,1,1,1,1,2,4,1,1,1,2
-1297327,5,1,1,1,2,1,1,1,1,2
-1297522,2,1,1,1,2,1,1,1,1,2
-1298360,1,1,1,1,2,1,1,1,1,2
-1299924,5,1,1,1,2,1,2,1,1,2
-1299994,5,1,1,1,2,1,1,1,1,2
-1304595,3,1,1,1,1,1,2,1,1,2
-1306282,6,6,7,10,3,10,8,10,2,4
-1313325,4,10,4,7,3,10,9,10,1,4
-1320077,1,1,1,1,1,1,1,1,1,2
-1320077,1,1,1,1,1,1,2,1,1,2
-1320304,3,1,2,2,2,1,1,1,1,2
-1330439,4,7,8,3,4,10,9,1,1,4
-333093,1,1,1,1,3,1,1,1,1,2
-369565,4,1,1,1,3,1,1,1,1,2
-412300,10,4,5,4,3,5,7,3,1,4
-672113,7,5,6,10,4,10,5,3,1,4
-749653,3,1,1,1,2,1,2,1,1,2
-769612,3,1,1,2,2,1,1,1,1,2
-769612,4,1,1,1,2,1,1,1,1,2
-798429,4,1,1,1,2,1,3,1,1,2
-807657,6,1,3,2,2,1,1,1,1,2
-8233704,4,1,1,1,1,1,2,1,1,2
-837480,7,4,4,3,4,10,6,9,1,4
-867392,4,2,2,1,2,1,2,1,1,2
-869828,1,1,1,1,1,1,3,1,1,2
-1043068,3,1,1,1,2,1,2,1,1,2
-1056171,2,1,1,1,2,1,2,1,1,2
-1061990,1,1,3,2,2,1,3,1,1,2
-1113061,5,1,1,1,2,1,3,1,1,2
-1116192,5,1,2,1,2,1,3,1,1,2
-1135090,4,1,1,1,2,1,2,1,1,2
-1145420,6,1,1,1,2,1,2,1,1,2
-1158157,5,1,1,1,2,2,2,1,1,2
-1171578,3,1,1,1,2,1,1,1,1,2
-1174841,5,3,1,1,2,1,1,1,1,2
-1184586,4,1,1,1,2,1,2,1,1,2
-1186936,2,1,3,2,2,1,2,1,1,2
-1197527,5,1,1,1,2,1,2,1,1,2
-1222464,6,10,10,10,4,10,7,10,1,4
-1240603,2,1,1,1,1,1,1,1,1,2
-1240603,3,1,1,1,1,1,1,1,1,2
-1241035,7,8,3,7,4,5,7,8,2,4
-1287971,3,1,1,1,2,1,2,1,1,2
-1289391,1,1,1,1,2,1,3,1,1,2
-1299924,3,2,2,2,2,1,4,2,1,2
-1306339,4,4,2,1,2,5,2,1,2,2
-1313658,3,1,1,1,2,1,1,1,1,2
-1313982,4,3,1,1,2,1,4,8,1,2
-1321264,5,2,2,2,1,1,2,1,1,2
-1321321,5,1,1,3,2,1,1,1,1,2
-1321348,2,1,1,1,2,1,2,1,1,2
-1321931,5,1,1,1,2,1,2,1,1,2
-1321942,5,1,1,1,2,1,3,1,1,2
-1321942,5,1,1,1,2,1,3,1,1,2
-1328331,1,1,1,1,2,1,3,1,1,2
-1328755,3,1,1,1,2,1,2,1,1,2
-1331405,4,1,1,1,2,1,3,2,1,2
-1331412,5,7,10,10,5,10,10,10,1,4
-1333104,3,1,2,1,2,1,3,1,1,2
-1334071,4,1,1,1,2,3,2,1,1,2
-1343068,8,4,4,1,6,10,2,5,2,4
-1343374,10,10,8,10,6,5,10,3,1,4
-1344121,8,10,4,4,8,10,8,2,1,4
-142932,7,6,10,5,3,10,9,10,2,4
-183936,3,1,1,1,2,1,2,1,1,2
-324382,1,1,1,1,2,1,2,1,1,2
-378275,10,9,7,3,4,2,7,7,1,4
-385103,5,1,2,1,2,1,3,1,1,2
-690557,5,1,1,1,2,1,2,1,1,2
-695091,1,1,1,1,2,1,2,1,1,2
-695219,1,1,1,1,2,1,2,1,1,2
-824249,1,1,1,1,2,1,3,1,1,2
-871549,5,1,2,1,2,1,2,1,1,2
-878358,5,7,10,6,5,10,7,5,1,4
-1107684,6,10,5,5,4,10,6,10,1,4
-1115762,3,1,1,1,2,1,1,1,1,2
-1217717,5,1,1,6,3,1,1,1,1,2
-1239420,1,1,1,1,2,1,1,1,1,2
-1254538,8,10,10,10,6,10,10,10,1,4
-1261751,5,1,1,1,2,1,2,2,1,2
-1268275,9,8,8,9,6,3,4,1,1,4
-1272166,5,1,1,1,2,1,1,1,1,2
-1294261,4,10,8,5,4,1,10,1,1,4
-1295529,2,5,7,6,4,10,7,6,1,4
-1298484,10,3,4,5,3,10,4,1,1,4
-1311875,5,1,2,1,2,1,1,1,1,2
-1315506,4,8,6,3,4,10,7,1,1,4
-1320141,5,1,1,1,2,1,2,1,1,2
-1325309,4,1,2,1,2,1,2,1,1,2
-1333063,5,1,3,1,2,1,3,1,1,2
-1333495,3,1,1,1,2,1,2,1,1,2
-1334659,5,2,4,1,1,1,1,1,1,2
-1336798,3,1,1,1,2,1,2,1,1,2
-1344449,1,1,1,1,1,1,2,1,1,2
-1350568,4,1,1,1,2,1,2,1,1,2
-1352663,5,4,6,8,4,1,8,10,1,4
-188336,5,3,2,8,5,10,8,1,2,4
-352431,10,5,10,3,5,8,7,8,3,4
-353098,4,1,1,2,2,1,1,1,1,2
-411453,1,1,1,1,2,1,1,1,1,2
-557583,5,10,10,10,10,10,10,1,1,4
-636375,5,1,1,1,2,1,1,1,1,2
-736150,10,4,3,10,3,10,7,1,2,4
-803531,5,10,10,10,5,2,8,5,1,4
-822829,8,10,10,10,6,10,10,10,10,4
-1016634,2,3,1,1,2,1,2,1,1,2
-1031608,2,1,1,1,1,1,2,1,1,2
-1041043,4,1,3,1,2,1,2,1,1,2
-1042252,3,1,1,1,2,1,2,1,1,2
-1061990,4,1,1,1,2,1,2,1,1,2
-1073836,5,1,1,1,2,1,2,1,1,2
-1083817,3,1,1,1,2,1,2,1,1,2
-1096352,6,3,3,3,3,2,6,1,1,2
-1140597,7,1,2,3,2,1,2,1,1,2
-1149548,1,1,1,1,2,1,1,1,1,2
-1174009,5,1,1,2,1,1,2,1,1,2
-1183596,3,1,3,1,3,4,1,1,1,2
-1190386,4,6,6,5,7,6,7,7,3,4
-1190546,2,1,1,1,2,5,1,1,1,2
-1213273,2,1,1,1,2,1,1,1,1,2
-1218982,4,1,1,1,2,1,1,1,1,2
-1225382,6,2,3,1,2,1,1,1,1,2
-1235807,5,1,1,1,2,1,2,1,1,2
-1238777,1,1,1,1,2,1,1,1,1,2
-1253955,8,7,4,4,5,3,5,10,1,4
-1257366,3,1,1,1,2,1,1,1,1,2
-1260659,3,1,4,1,2,1,1,1,1,2
-1268952,10,10,7,8,7,1,10,10,3,4
-1275807,4,2,4,3,2,2,2,1,1,2
-1277792,4,1,1,1,2,1,1,1,1,2
-1277792,5,1,1,3,2,1,1,1,1,2
-1285722,4,1,1,3,2,1,1,1,1,2
-1288608,3,1,1,1,2,1,2,1,1,2
-1290203,3,1,1,1,2,1,2,1,1,2
-1294413,1,1,1,1,2,1,1,1,1,2
-1299596,2,1,1,1,2,1,1,1,1,2
-1303489,3,1,1,1,2,1,2,1,1,2
-1311033,1,2,2,1,2,1,1,1,1,2
-1311108,1,1,1,3,2,1,1,1,1,2
-1315807,5,10,10,10,10,2,10,10,10,4
-1318671,3,1,1,1,2,1,2,1,1,2
-1319609,3,1,1,2,3,4,1,1,1,2
-1323477,1,2,1,3,2,1,2,1,1,2
-1324572,5,1,1,1,2,1,2,2,1,2
-1324681,4,1,1,1,2,1,2,1,1,2
-1325159,3,1,1,1,2,1,3,1,1,2
-1326892,3,1,1,1,2,1,2,1,1,2
-1330361,5,1,1,1,2,1,2,1,1,2
-1333877,5,4,5,1,8,1,3,6,1,2
-1334015,7,8,8,7,3,10,7,2,3,4
-1334667,1,1,1,1,2,1,1,1,1,2
-1339781,1,1,1,1,2,1,2,1,1,2
-1339781,4,1,1,1,2,1,3,1,1,2
-13454352,1,1,3,1,2,1,2,1,1,2
-1345452,1,1,3,1,2,1,2,1,1,2
-1345593,3,1,1,3,2,1,2,1,1,2
-1347749,1,1,1,1,2,1,1,1,1,2
-1347943,5,2,2,2,2,1,1,1,2,2
-1348851,3,1,1,1,2,1,3,1,1,2
-1350319,5,7,4,1,6,1,7,10,3,4
-1350423,5,10,10,8,5,5,7,10,1,4
-1352848,3,10,7,8,5,8,7,4,1,4
-1353092,3,2,1,2,2,1,3,1,1,2
-1354840,2,1,1,1,2,1,3,1,1,2
-1354840,5,3,2,1,3,1,1,1,1,2
-1355260,1,1,1,1,2,1,2,1,1,2
-1365075,4,1,4,1,2,1,1,1,1,2
-1365328,1,1,2,1,2,1,2,1,1,2
-1368267,5,1,1,1,2,1,1,1,1,2
-1368273,1,1,1,1,2,1,1,1,1,2
-1368882,2,1,1,1,2,1,1,1,1,2
-1369821,10,10,10,10,5,10,10,10,7,4
-1371026,5,10,10,10,4,10,5,6,3,4
-1371920,5,1,1,1,2,1,3,2,1,2
-466906,1,1,1,1,2,1,1,1,1,2
-466906,1,1,1,1,2,1,1,1,1,2
-534555,1,1,1,1,2,1,1,1,1,2
-536708,1,1,1,1,2,1,1,1,1,2
-566346,3,1,1,1,2,1,2,3,1,2
-603148,4,1,1,1,2,1,1,1,1,2
-654546,1,1,1,1,2,1,1,1,8,2
-654546,1,1,1,3,2,1,1,1,1,2
-695091,5,10,10,5,4,5,4,4,1,4
-714039,3,1,1,1,2,1,1,1,1,2
-763235,3,1,1,1,2,1,2,1,2,2
-776715,3,1,1,1,3,2,1,1,1,2
-841769,2,1,1,1,2,1,1,1,1,2
-888820,5,10,10,3,7,3,8,10,2,4
-897471,4,8,6,4,3,4,10,6,1,4
-897471,4,8,8,5,4,5,10,4,1,4
diff --git a/breast-cancer-wisconsin.names b/breast-cancer-wisconsin.names
deleted file mode 100644
index 54b59a12..00000000
--- a/breast-cancer-wisconsin.names
+++ /dev/null
@@ -1,126 +0,0 @@
-Citation Request:
- This breast cancer databases was obtained from the University of Wisconsin
- Hospitals, Madison from Dr. William H. Wolberg. If you publish results
- when using this database, then please include this information in your
- acknowledgements. Also, please cite one or more of:
-
- 1. O. L. Mangasarian and W. H. Wolberg: "Cancer diagnosis via linear
- programming", SIAM News, Volume 23, Number 5, September 1990, pp 1 & 18.
-
- 2. William H. Wolberg and O.L. Mangasarian: "Multisurface method of
- pattern separation for medical diagnosis applied to breast cytology",
- Proceedings of the National Academy of Sciences, U.S.A., Volume 87,
- December 1990, pp 9193-9196.
-
- 3. O. L. Mangasarian, R. Setiono, and W.H. Wolberg: "Pattern recognition
- via linear programming: Theory and application to medical diagnosis",
- in: "Large-scale numerical optimization", Thomas F. Coleman and Yuying
- Li, editors, SIAM Publications, Philadelphia 1990, pp 22-30.
-
- 4. K. P. Bennett & O. L. Mangasarian: "Robust linear programming
- discrimination of two linearly inseparable sets", Optimization Methods
- and Software 1, 1992, 23-34 (Gordon & Breach Science Publishers).
-
-1. Title: Wisconsin Breast Cancer Database (January 8, 1991)
-
-2. Sources:
- -- Dr. WIlliam H. Wolberg (physician)
- University of Wisconsin Hospitals
- Madison, Wisconsin
- USA
- -- Donor: Olvi Mangasarian (mangasarian@cs.wisc.edu)
- Received by David W. Aha (aha@cs.jhu.edu)
- -- Date: 15 July 1992
-
-3. Past Usage:
-
- Attributes 2 through 10 have been used to represent instances.
- Each instance has one of 2 possible classes: benign or malignant.
-
- 1. Wolberg,~W.~H., \& Mangasarian,~O.~L. (1990). Multisurface method of
- pattern separation for medical diagnosis applied to breast cytology. In
- {\it Proceedings of the National Academy of Sciences}, {\it 87},
- 9193--9196.
- -- Size of data set: only 369 instances (at that point in time)
- -- Collected classification results: 1 trial only
- -- Two pairs of parallel hyperplanes were found to be consistent with
- 50% of the data
- -- Accuracy on remaining 50% of dataset: 93.5%
- -- Three pairs of parallel hyperplanes were found to be consistent with
- 67% of data
- -- Accuracy on remaining 33% of dataset: 95.9%
-
- 2. Zhang,~J. (1992). Selecting typical instances in instance-based
- learning. In {\it Proceedings of the Ninth International Machine
- Learning Conference} (pp. 470--479). Aberdeen, Scotland: Morgan
- Kaufmann.
- -- Size of data set: only 369 instances (at that point in time)
- -- Applied 4 instance-based learning algorithms
- -- Collected classification results averaged over 10 trials
- -- Best accuracy result:
- -- 1-nearest neighbor: 93.7%
- -- trained on 200 instances, tested on the other 169
- -- Also of interest:
- -- Using only typical instances: 92.2% (storing only 23.1 instances)
- -- trained on 200 instances, tested on the other 169
-
-4. Relevant Information:
-
- Samples arrive periodically as Dr. Wolberg reports his clinical cases.
- The database therefore reflects this chronological grouping of the data.
- This grouping information appears immediately below, having been removed
- from the data itself:
-
- Group 1: 367 instances (January 1989)
- Group 2: 70 instances (October 1989)
- Group 3: 31 instances (February 1990)
- Group 4: 17 instances (April 1990)
- Group 5: 48 instances (August 1990)
- Group 6: 49 instances (Updated January 1991)
- Group 7: 31 instances (June 1991)
- Group 8: 86 instances (November 1991)
- -----------------------------------------
- Total: 699 points (as of the donated datbase on 15 July 1992)
-
- Note that the results summarized above in Past Usage refer to a dataset
- of size 369, while Group 1 has only 367 instances. This is because it
- originally contained 369 instances; 2 were removed. The following
- statements summarizes changes to the original Group 1's set of data:
-
- ##### Group 1 : 367 points: 200B 167M (January 1989)
- ##### Revised Jan 10, 1991: Replaced zero bare nuclei in 1080185 & 1187805
- ##### Revised Nov 22,1991: Removed 765878,4,5,9,7,10,10,10,3,8,1 no record
- ##### : Removed 484201,2,7,8,8,4,3,10,3,4,1 zero epithelial
- ##### : Changed 0 to 1 in field 6 of sample 1219406
- ##### : Changed 0 to 1 in field 8 of following sample:
- ##### : 1182404,2,3,1,1,1,2,0,1,1,1
-
-5. Number of Instances: 699 (as of 15 July 1992)
-
-6. Number of Attributes: 10 plus the class attribute
-
-7. Attribute Information: (class attribute has been moved to last column)
-
- # Attribute Domain
- -- -----------------------------------------
- 1. Sample code number id number
- 2. Clump Thickness 1 - 10
- 3. Uniformity of Cell Size 1 - 10
- 4. Uniformity of Cell Shape 1 - 10
- 5. Marginal Adhesion 1 - 10
- 6. Single Epithelial Cell Size 1 - 10
- 7. Bare Nuclei 1 - 10
- 8. Bland Chromatin 1 - 10
- 9. Normal Nucleoli 1 - 10
- 10. Mitoses 1 - 10
- 11. Class: (2 for benign, 4 for malignant)
-
-8. Missing attribute values: 16
-
- There are 16 instances in Groups 1 to 6 that contain a single missing
- (i.e., unavailable) attribute value, now denoted by "?".
-
-9. Class distribution:
-
- Benign: 458 (65.5%)
- Malignant: 241 (34.5%)
diff --git a/coq/API.v b/coq/API.v
deleted file mode 100644
index ac550a47..00000000
--- a/coq/API.v
+++ /dev/null
@@ -1,58 +0,0 @@
-Require Import FloatishIEEE.
-Require Import ExtrFloatishIEEE.
-
-
-(* Require Import ExtrR. *)
-(* Our stuff modules *)
-
-Require Import Utils.
-Require Import Vector.
-Require Gen_NN.
-Require Import DefinedFunctions.
-Require Import FloatishDef.
-Require Import BinInt.
-Require Import String.
-Require Import Streams.
-Local Open Scope list.
-
-Existing Instance floatish_IEEE.
-
-Example test :=
- mk_env_entry (Name "f", DTfloat) (FfromZ 1)%Z ::
- mk_env_entry (Name "v", DTVector 3) (ConstVector 3 ((FfromZ (-2)))%Z) ::
- mk_env_entry (Name "m", DTMatrix 2 3) (ConstMatrix 2 3 (FfromZ 3))%Z :: nil.
-Module API.
- Example opt := @Gen_NN.opt floatish_IEEE.
- Example opt2 := @Gen_NN.opt2 floatish_IEEE.
- Example test_update := @Gen_NN.test_update floatish_IEEE.
- Example testopt := @Gen_NN.testopt floatish_IEEE.
- Example testreeopt := @Gen_NN.testreeopt floatish_IEEE.
- Example gradenv := @Gen_NN.gradenv floatish_IEEE.
- Example gradenv_tree := @Gen_NN.gradenv_tree floatish_IEEE.
- Example test_env := test.
-
- Example discard_first {A} (l:list (list A)) : list (list A) := List.map (@List.tl A) l.
- Definition normalizeIntData := Gen_NN.normalizeIntData.
- Definition init_env2 := Gen_NN.init_env2.
- CoFixpoint mkIndexedStream {A} (i : nat) (ran : nat -> A) : Stream A :=
- Cons (ran i) (mkIndexedStream (S i) ran).
- Definition streamtake := Gen_NN.streamtake.
- Definition df_env := DefinedFunctions.df_env.
- Definition eval_wisconsin_batch (nsamp:nat)
- (env:df_env) (data : Matrix float nsamp 10) : list float :=
- match Gen_NN.eval_wisconsin_batch nsamp env data with
- | Some val => val :: nil
- | _ => nil
- end.
-
- Definition wisconsin_test := Gen_NN.wisconsin_test.
- Definition wisconsin_test_env := Gen_NN.wisconsin_test_env.
- Definition wisconsin_gradenv_tree := Gen_NN.wisconsin_gradenv_tree.
- Definition wisconsin_gradenv := Gen_NN.wisconsin_gradenv.
- Definition nn_test := Gen_NN.NN_test.
- Definition nn_test_val := Gen_NN.NN_test_val.
- Definition nn_test_env := Gen_NN.NN_test_env.
- Definition nn_test_gradenv_tree := Gen_NN.NN_test_gradenv_tree.
- Definition nn_test_gradenv := Gen_NN.NN_test_gradenv.
-
- End API.
diff --git a/coq/FHE/arith.v b/coq/FHE/arith.v
deleted file mode 100644
index 913793d2..00000000
--- a/coq/FHE/arith.v
+++ /dev/null
@@ -1,3455 +0,0 @@
-Require Import Reals Lra Lia List Permutation.
-From mathcomp Require Import common ssreflect fintype bigop ssrnat matrix Rstruct complex seq fingroup.
-From mathcomp Require Import ssralg ssrfun.
-From mathcomp Require Import generic_quotient ring_quotient.
-From mathcomp Require Import poly mxpoly polydiv ssrint zmodp eqtype ssrbool div order.
-From mathcomp Require Import ring ssrZ.
-
-Import ssralg.GRing.
-Import ssrnum.Num.Theory.
-Require Import nth_root encode zp_prim_root.
-
-Ltac coq_lra := lra.
-From mathcomp Require Import lra.
-
-Set Bullet Behavior "Strict Subproofs".
-
-
-Local Open Scope ring_scope.
-
-
-Record ENCODE : Type :=
- mkENCODE
- { clear : {poly int} ;
- scale : nat
- }.
-
-Record FHE : Type :=
- mkFHE
- { q : nat;
- cypher : {poly {poly 'Z_q}} ;
- norm_bound : R ;
- noise_bound : R
- }.
-
-Definition FHE_add {q : nat} (P Q : {poly {poly 'Z_q}} ) := P + Q.
-
-Definition FHE_mult_base {q : nat} (P Q : {poly {poly 'Z_q}} ) := P * Q.
-
-Definition zliftc {q : nat} (c : 'Z_q) : int :=
- if (c <= q/2) then c%:Z else c%:Z - q%:Z.
-
-Lemma zliftc_bound {q : nat} (c : 'Z_q) :
- 1 < q ->
- `| zliftc c | <= q/2.
-Proof.
- intros.
- rewrite /zliftc.
- case: (boolP (c <= q/2)); intros.
- - by destruct c.
- - destruct c.
- simpl in *.
- move /negP in i.
- assert (m > (Nat.divmod q 1 0 1).1) by lia.
- rewrite Zp_cast // in i0.
- replace `|m-q| with (q - m)%N by (rewrite distnEr; lia).
- assert (q <= 2*(Nat.divmod q 1 0 1).1 + 1).
- {
- move: (Nat.divmod_spec q 1 0 1 (le_refl _)) => /=.
- case: (Nat.divmod q 1 0 1) => /= x u.
- rewrite !Nat.add_0_r.
- move=> [-> ubound].
- destruct u; lia.
- }
- lia.
- Qed.
-
-Lemma modp_small (q : nat) (m : nat) :
- m < (Zp_trunc q).+2 ->
- nat_of_ord (intmul (1 : 'Z_q) (Posz m)) = m.
-Proof.
- rewrite /intmul Zp_nat /=.
- by apply/modn_small.
-Qed.
-
-Lemma modpp (q : nat) :
- nat_of_ord (intmul (1 : 'Z_q) (Posz (Zp_trunc q).+2)) = 0%nat.
-Proof.
- by rewrite /intmul Zp_nat /= modnn.
-Qed.
-
-Lemma modpp' (q : nat) :
- 1 < q ->
- intmul (1 : 'Z_q) (Posz q) = 0.
-Proof.
- intros.
- apply ord_inj.
- by rewrite /intmul Zp_nat /= Zp_cast // modnn.
-Qed.
-
-Lemma zliftc_valid {q : nat} (c : 'Z_q) :
- 1 < q ->
- (zliftc c) %:~R = c.
-Proof.
- intros.
- rewrite /zliftc.
- case: (c <= q/2).
- - destruct c.
- apply ord_inj => /=.
- by rewrite modp_small.
- - destruct c.
- rewrite intrD mulrNz modpp' // oppr0 addr0.
- apply ord_inj => /=.
- by rewrite modp_small.
-Qed.
-
-Lemma zliftc_add2 {q : nat} (a b : 'Z_q) :
- 1 < q ->
- `|zliftc (a + b)%R - ((zliftc a) + (zliftc b))%R | <= q.
-Proof.
- have diveq: (Nat.divmod q 1 0 1).1 = (q / 2)%nat by [].
-
- move=> qbig.
- rewrite /zliftc /=.
- Ltac t1 C :=
- match goal with
- | [|- is_true (leq (absz ?x) _) ] =>
- have ->: x = C by lia
- end.
-
- Ltac t2 C := t1 C
- ; case: (boolP (((Zp_trunc q).+1 < _ + _))) => // _
- ; (rewrite ?mul0n ?mul1n ?Zp_cast // ?distnn); lia.
-
- case: (boolP ( (a + b) %% (Zp_trunc q).+2 <= (Nat.divmod q 1 0 1).1)) => ltab
- ; case: (boolP (a <= (Nat.divmod q 1 0 1).1)) => lta
- ; case: (boolP (b <= (Nat.divmod q 1 0 1).1)) => ltb
- ; rewrite modnD // [modn a _]modn_small ?[modn b _]modn_small; try apply ltn_ord.
- - t2 (opp (Posz (muln (leq (S (S (Zp_trunc q))) (addn a b)) (S (S (Zp_trunc q)))))).
- - t2 (add (V:=int_ZmodType) q (opp (Posz (muln (leq (S (S (Zp_trunc q))) (addn a b)) (S (S (Zp_trunc q))))))).
- - t2 (add (V:=int_ZmodType) (Posz q) (opp (Posz (muln (leq (S (S (Zp_trunc q))) (addn a b)) (S (S (Zp_trunc q))))))).
- - have eqq: ((Zp_trunc q).+1 < a + b).
- { move: (Nat.divmod_spec q 1 0 1) lta ltb.
- case: (Nat.divmod q 1 0 1) => /= x u.
- move/(_ (le_n _)) => [eqq1 eqq2] lta ltb.
- rewrite !Nat.add_0_r in eqq1.
- rewrite {1}Zp_cast //.
- destruct u; simpl in *; lia.
- }
- rewrite eqq mul1n.
- rewrite {1}Zp_cast // in eqq.
- rewrite {3}Zp_cast //.
- by t1 (Posz q).
- - case: (boolP ((Zp_trunc q).+1 < a + b)).
- + rewrite mul1n {1}Zp_cast // => ineq2.
- simpl in lta.
- rewrite diveq in lta ltb ltab.
- case: (eqVneq q (nat_of_ord a + nat_of_ord b))%nat =>eqq2.
- * case/negP: ltab.
- by rewrite {3}Zp_cast // {3}eqq2 modnn.
- * suff: 2 * (q / 2)%nat <= q by lia.
- apply /leP.
- by apply Nat.mul_div_le.
- + rewrite mul0n => _.
- t1 (opp (Posz q)).
- by rewrite abszN.
- - t2 (opp (Posz (muln (leq (S (S (Zp_trunc q))) (addn a b)) (S (S (Zp_trunc q)))))).
- - t2 (opp (Posz (muln (leq (S (S (Zp_trunc q))) (addn a b)) (S (S (Zp_trunc q)))))).
- - have eqq: ((Zp_trunc q).+1 < a + b).
- { move: (Nat.divmod_spec q 1 0 1) lta ltb.
- case: (Nat.divmod q 1 0 1) => /= x u.
- move/(_ (le_n _)) => [eqq1 eqq2] lta ltb.
- rewrite !Nat.add_0_r in eqq1.
- rewrite {1}Zp_cast //.
- destruct u; simpl in *; lia.
- }
- rewrite eqq mul1n.
- rewrite {1}Zp_cast // in eqq.
- rewrite {3}Zp_cast //; lia.
-Qed.
-
-
-Lemma bounded_dvdn_cases (a q : nat) :
- 1 < q ->
- (q %| a)%N ->
- a < 2 * q ->
- (a == q) || (a == 0%nat).
-Proof.
- move=> qbig.
- move/dvdnP => [k ->].
- rewrite ltn_mul2r.
- move/andP=>[_ -].
- case: k; [| case]; lia.
-Qed.
-
-
-Lemma bounded_dvdn (a q : nat) :
- 1 < q ->
- (q %| a)%N ->
- a < 2 * q ->
- { c : nat |
- a = (c * q)%N /\ c <= 1}.
-Proof.
- move=> qbig qdiva asmall.
- move: (bounded_dvdn_cases a q qbig qdiva asmall).
- case: (eqVneq a q) => /=.
- - exists 1%nat. lia.
- - exists 0%nat. lia.
-Qed.
-
-Lemma bounded_divi (a : int) (q : nat) :
- 1 < q ->
- (q %| `|a|)%N ->
- `|a| < 2 * q ->
- { c : nat |
- `| a | = (c * q)%N /\ c <= 1}.
-Proof.
- by apply bounded_dvdn.
-Qed.
-
-Lemma bounded_divn_alt (a : nat) (q k : nat) :
- 1 < q ->
- (q %| a)%N ->
- a <= k * q ->
- { c : nat | a = (c * q)%N /\ c <= k}.
-Proof.
- move=> qbig.
- exists (a %/ q)%nat.
- by rewrite divnK // leq_divLR.
-Qed.
-
-Lemma bounded_divi_alt (a : int) (q k : nat) :
- 1 < q ->
- (q %| `|a|)%N ->
- `|a| <= k * q ->
- { c : nat | `| a | = (c * q)%N /\ c <= k}.
-Proof.
- apply bounded_divn_alt.
-Qed.
-
-Lemma absz_triang (a b : int) :
- `|a + b| <= `|a| + `|b|.
-Proof.
- lia.
-Qed.
-
-Lemma Zp_int (p:nat) (a : int) (pbig:1
inZp n
- | Negz n => inZp (p - modn (n.+1) p)
- end.
-Proof.
- destruct a.
- - by rewrite /intmul Zp_nat.
- - rewrite /intmul.
- rewrite /intmul Zp_nat.
- rewrite /opp /= /Zp_opp /=.
- apply ord_inj => /=.
- by rewrite Zp_cast //.
-Qed.
-
-Lemma Z_q_0_dvd_abs {q : nat} (a : int) :
- 1 < q ->
- a %:~R = (0 : 'Z_q) ->
- (q %| `|a|)%N.
-Proof.
- move => qbig.
- rewrite Zp_int //.
- case: a => n /=.
- - move/(f_equal val) => /=.
- rewrite Zp_cast //.
- by move/eqP.
- - move/(f_equal val) => /=.
- rewrite Zp_cast //.
- rewrite modnB; try lia.
- rewrite modnn modn_mod.
- lia.
- Qed.
-
-Lemma zliftc_add2_ex {q : nat} (a b : 'Z_q) :
- 1 < q ->
- { c : nat |
- `|zliftc (a + b)%R - ((zliftc a) + (zliftc b))%R | = (c * q)%N /\
- c <= 1}.
-Proof.
- intros.
- apply bounded_divi; trivial.
- - apply Z_q_0_dvd_abs; trivial.
- rewrite rmorphD rmorphN !rmorphD /=.
- rewrite !zliftc_valid //.
- ring.
- - move: (zliftc_bound a H) => a_bound.
- move: (zliftc_bound b H) => b_bound.
- move: (zliftc_bound (a + b) H) => ab_bound.
- move: (absz_triang (zliftc a) (zliftc b)) => triang_ab.
- move: (absz_triang (zliftc (a+b)) (- (zliftc a + zliftc b))) => triang_abc.
- rewrite -abszN in triang_ab.
- assert (q/2 + q/2 + q/2 < 2 * q)%N.
- {
- move: (Nat.divmod_spec q 1 0 1 (le_refl _)) => /=.
- case: Nat.divmod => /= x y.
- rewrite !Nat.add_0_r.
- move=> [qeq ubound]; subst.
- destruct y; lia.
- }
- lia.
-Qed.
-
-Lemma zliftc_mul2 {q : nat} (a b : 'Z_q) :
- 1 < q ->
- `|(zliftc (a * b) - (zliftc a) * (zliftc b))%R | <= (q/2 * (1 + q/2))%N.
-Proof.
- intros.
- move: (zliftc_bound a H) => a_bound.
- move: (zliftc_bound b H) => b_bound.
- move: (zliftc_bound (a * b) H) => ab_bound.
- move: (absz_triang (zliftc (a*b)) (- (zliftc a * zliftc b))) => triang.
- rewrite abszN in triang.
- assert (`|zliftc a * zliftc b| <= (q/2) * (q/2)).
- {
- rewrite abszM.
- by apply leq_mul.
- }
- lia.
- Qed.
-
-Lemma zliftc_mul2_ex {q : nat} (a b : 'Z_q) :
- 1 < q ->
- { c : nat |
- `|zliftc (a * b)%R - ((zliftc a) * (zliftc b))%R | = (c * q)%N /\
- c <= q/2}.
-Proof.
- intros.
- apply bounded_divi_alt; trivial.
- - apply Z_q_0_dvd_abs; trivial.
- rewrite rmorphD rmorphN rmorphM /=.
- rewrite !zliftc_valid //.
- ring.
- - generalize (zliftc_mul2 a b H); intros.
- assert ((q / 2) * (1 + q / 2) <= (q/2) * q).
- {
- rewrite -Nat.div2_div leq_pmul2l.
- - generalize (Nat.lt_div2 q); lia.
- - generalize (div2_not_R0 q); lia.
- }
- lia.
-Qed.
-
-Definition zlift {q : nat} (a : {poly 'Z_q}) : {poly int} :=
- map_poly zliftc a.
-
-Lemma zliftc0 (q : nat) :
- zliftc (0 : 'Z_q) = 0.
-Proof.
- by rewrite /zliftc.
-Qed.
-
-Lemma zlift0 {q : nat} (a : {poly 'Z_q}) :
- a = 0 ->
- zlift a = 0.
-Proof.
- intros.
- rewrite /zlift /zliftc H.
- apply map_poly0.
-Qed.
-
-Lemma zlift0_alt {q : nat} :
- zlift (0 : {poly 'Z_q}) = 0.
-Proof.
- rewrite /zlift /zliftc.
- apply map_poly0.
-Qed.
-
-
-Definition icoef_maxnorm_ord (p : {poly int}):nat := \max_(j < seq.size p) `|p`_ j|.
-Definition icoef_maxnorm_nat (p : {poly int}):nat := \max_(0 <= j < seq.size p) `|p`_ j|.
-Definition icoef_maxnorm (p : {poly int}):nat := \max_(j < seq.size p) `|p`_ j|.
-
-Lemma icoef_maxnorm_conv (p : {poly int}) :
- \max_(j < seq.size p) `|p`_ j| = \max_(0 <= j < seq.size p) `|p`_ j|.
-Proof.
- by rewrite big_mkord.
-Qed.
-
-Lemma zlift_add2_ex {q : nat} (a b : {poly 'Z_q}) :
- (1 < q) ->
- { c : {poly int} |
- zlift (a + b) = zlift a + zlift b + c /\
- icoef_maxnorm c <= q}.
-Proof.
- exists (zlift (a + b) - (zlift a + zlift b)).
- split.
- - ring.
- - rewrite /icoef_maxnorm /zlift.
- apply /bigmax_leqP => i _.
- rewrite !(coefD,coefN).
- rewrite !coef_map_id0; try apply zliftc0.
- rewrite !coefD; try apply zlift0_alt.
- by apply zliftc_add2.
-Qed.
-
-Definition zlift_add2_perturb {q : nat} (qbig:(1 < q)) (a b : {poly 'Z_q}) : {poly int}
- := sval (zlift_add2_ex a b qbig).
-
-Lemma zlift_add2_eq {q : nat} (qbig:(1 < q)) (a b : {poly 'Z_q})
- : zlift (a + b) = zlift a + zlift b + zlift_add2_perturb qbig a b.
-Proof.
- rewrite /zlift_add2_perturb.
- case: (zlift_add2_perturb qbig a b).
- case: zlift_add2_ex; intros; simpl.
- tauto.
-Qed.
-
-Lemma zlift_add2_perturb_small {q : nat} (qbig:(1 < q)) (a b : {poly 'Z_q})
- : icoef_maxnorm (zlift_add2_perturb qbig a b) <= q.
-Proof.
- rewrite /zlift_add2_perturb.
- case: (zlift_add2_perturb qbig a b).
- case: zlift_add2_ex; intros; simpl.
- tauto.
-Qed.
-
-Definition upi (c:R) : int := int_of_Z (up c).
-
-(* 0 <= rand < 1 *)
-Definition ran_round (x rand : R) : int :=
- let hi := upi x in
- if (hi%:~R - x < rand)%O then hi else (hi - 1).
-
-Definition nearest_round (x : R) : int := ran_round x (1/2)%R.
-
-Definition nearest_round_sgn (x : R) : int :=
- if (x < 0)%O then - (nearest_round (-x)) else nearest_round x.
-
-Definition nearest_round_int (n d : int) : int := nearest_round ((n %:~R)/(d %:~R))%R.
-Definition nearest_round_int_sgn (n d : int) : int := nearest_round_sgn ((n %:~R)/(d %:~R))%R.
-
-Lemma inv_nat_pos (d : nat) :
- d != 0%N ->
- (inv d%:~R > (0 : R))%O.
-Proof.
- rewrite invr_gt0.
- replace (zero (ssrnum.Num.NumDomain.porder_zmodType R_numDomainType)) with
- (((0 : int)%:~R : R)) by lra.
- rewrite ltr_int; lia.
-Qed.
-
-Lemma div_nat_pos (n d : nat) :
- d != 0%N ->
- ((0 : R) <= n%:~R / d %:~R)%O.
-Proof.
- intros.
- generalize (inv_nat_pos d H); intros.
- generalize (ler_pM2l H0 (0 : R) (n %:~R)); intros.
- rewrite mulr0 mulrC in H1.
- rewrite H1.
- replace (0 : R) with (((0 : int)%:~R : R) ) by lra.
- rewrite ler_int; lia.
-Qed.
-
-Lemma div_int_pos (d : nat) (n : int):
- d != 0%N ->
- ((0 : int) <= n)%O ->
- ((0 : R) <= n%:~R / d %:~R)%O.
-Proof.
- intros.
- destruct n.
- - by apply div_nat_pos.
- - lia.
-Qed.
-
-Lemma div_negz_neg (n d : nat) :
- d != 0%N ->
- (((Negz n)%:~R : R)/ d%:~R < 0)%O.
-Proof.
- intros.
- generalize (inv_nat_pos d H); intros.
- generalize (ltr_pM2l H0 ((Negz n) %:~R) (0 : R)); intros.
- rewrite mulr0 mulrC in H1.
- rewrite H1.
- replace (0 : R) with (((0 : int)%:~R : R) ) by lra.
- rewrite ltr_int; lia.
-Qed.
-
-Lemma div_int_neg (d : nat) (n : int):
- d != 0%N ->
- ((0 : int) > n)%O ->
- ((0 : R) > ((n%:~R / d %:~R) : R))%O.
-Proof.
- intros.
- destruct n.
- - lia.
- - by apply div_negz_neg.
-Qed.
-
-Lemma div_int_neg' (d : nat) (n : int):
- d != 0%N ->
- ((0 : int) >= n)%O ->
- ((0 : R) >= ((n%:~R / d %:~R) : R))%O.
-Proof.
- intros.
- destruct n.
- - assert (n = 0%N) by lia.
- rewrite H1 mul0r.
- by rewrite Order.POrderTheory.lexx.
- - generalize (div_negz_neg n d H); intros.
- lra.
-Qed.
-
-
-Lemma IZRE (n : Z) : IZR n = (int_of_Z n)%:~R.
-Proof.
- destruct n.
- - by rewrite /intmul /= /IZR R00.
- - by rewrite /IZR -INR_IPR INRE /int_of_Z /intmul.
- - rewrite /IZR -INR_IPR INRE /int_of_Z /intmul /opp /=.
- f_equal; f_equal.
- lia.
-Qed.
-
-Lemma IZREb (n : int) : n%:~R = IZR (ssrZ.Z_of_int n).
-Proof.
- by rewrite -{1}(Z_of_intK n) -IZRE.
-Qed.
-
-Lemma up_int_add (n : Z) (c : R) :
- up (Rplus (IZR n) c) = Zplus n (up c).
-Proof.
- symmetry.
- destruct (archimed c).
- apply tech_up; rewrite plus_IZR; coq_lra.
-Qed.
-
-Lemma upi_intl (n : int) (c : R) :
- upi ((n%:~R) + c) = n + upi c.
-Proof.
- rewrite /upi !IZREb up_int_add.
- lia.
-Qed.
-
-Lemma upi0 :
- upi 0 = 1.
-Proof.
- rewrite /upi -(tech_up 0 1); try coq_lra.
- lia.
-Qed.
-
-Lemma nearest_round_int0 (d : int) :
- nearest_round_int 0 d = 0.
-Proof.
- rewrite /nearest_round_int /nearest_round /ran_round.
- rewrite mul0r upi0 oppr0 addr0 addrN.
- rewrite /intmul /=.
- rewrite ltr_pdivlMr; last by lra.
- by rewrite mul1r /natmul/= gtrDl ltr10.
-Qed.
-
-Lemma nearest_round_int_add (n1 : int) (c : R) :
- nearest_round (n1 %:~R + c)%R = n1 + nearest_round c.
-Proof.
- rewrite /nearest_round /ran_round.
- have ->: (upi (n1%:~R + c))%:~R - (n1%:~R + c)%R = ((upi c)%:~R - c).
- {
- rewrite upi_intl; lra.
- }
- case: Order.TotalTheory.ltP => _ ; by rewrite upi_intl; lra.
-Qed.
-
-Lemma nearest_round_int_mul_add (n1 n2 d : int) :
- d <> 0 ->
- nearest_round_int (n1 * d + n2) d = n1 + nearest_round_int n2 d.
-Proof.
- intros.
- rewrite /nearest_round_int -nearest_round_int_add.
- rewrite intrD intrM mulrDl.
- rewrite -mulrA divff.
- - f_equal; lra.
- - rewrite intr_eq0.
- by apply/eqP.
-Qed.
-
-Lemma nearest_round_int_mul_add_r (n1 n2 d : int) :
- d <> 0 ->
- nearest_round_int (n1 + d * n2) d = nearest_round_int n1 d + n2.
-Proof.
- intros.
- rewrite mulrC addrC nearest_round_int_mul_add //.
- lia.
-Qed.
-
-Lemma nearest_round_int_mod (n1 n2 p1 p2 : int) :
- p1 <> 0 ->
- intdiv.modz n1 (p1 * p2) = intdiv.modz n2 (p1 * p2) ->
- intdiv.modz (nearest_round_int n1 p1) p2 = intdiv.modz (nearest_round_int n2 p1) p2.
-Proof.
- intros.
- assert (exists (h : int),
- n1 = n2 + h * (p1 * p2)).
- {
- exists (intdiv.divz n1 (p1 * p2) - intdiv.divz n2 (p1 * p2)).
- generalize (intdiv.divz_eq n1 (p1 * p2)); intros.
- generalize (intdiv.divz_eq n2 (p1 * p2)); intros.
- rewrite {1}H2 {1}H1 H0.
- ring.
- }
- destruct H1.
- rewrite H1 addrC (mulrC p1 p2) mulrA nearest_round_int_mul_add //.
- by rewrite intdiv.modzMDl.
-Qed.
-
-Lemma up_add2' (r1 r2 : R) :
- (up(r1 + r2) = Z.add (up r1) (up r2) \/ up(r1 + r2) = Z.sub (Z.add (up r1) (up r2)) 1)%Z.
-Proof.
- destruct (archimed r1).
- destruct (archimed r2).
- destruct (archimed (r1 + r2)).
- case: (boolP (Rleb ((IZR (up r1) + IZR (up r2)) - (r1 + r2)) 1)).
- - intros.
- left.
- move /RlebP in p.
- by rewrite (tech_up (r1 + r2) (up r1 + up r2)) // plus_IZR; coq_lra.
- - intros.
- right.
- move /RlebP in i.
- by rewrite (tech_up (r1 + r2) (up r1 + up r2 - 1)) // !plus_IZR opp_IZR; coq_lra.
-Qed.
-
-Lemma up_add2 (r1 r2 : R) :
- Z.abs_nat (Z.sub (up (r1 + r2)) (Z.add (up r1) (up r2))) <= 1.
-Proof.
- destruct (up_add2' r1 r2); rewrite H; lia.
-Qed.
-
-Lemma up_le (r1 r2 : R) :
- Rle r1 r2 ->
- Z.le (up r1) (up r2).
-Proof.
- case: (boolP (Rleb (IZR (up r1)) (IZR(up r2)))).
- - intros.
- apply le_IZR.
- by move /RlebP in p.
- - intros.
- move /RlebP in i.
- assert (Rlt (IZR (up r2)) (IZR (up r1))).
- {
- coq_lra.
- }
- apply lt_IZR in H0.
- assert (Z.le (Z.add (up r2) (1 : Z)) (up r1)).
- {
- lia.
- }
- apply IZR_le in H1.
- rewrite /one /= plus_IZR in H1.
- destruct (archimed r1).
- destruct (archimed r2).
- coq_lra.
-Qed.
-
-Lemma upi_le (r1 r2 : R) :
- (r1 <= r2)%O ->
- (upi r1 <= upi r2)%O.
-Proof.
- intros.
- move /RlebP in H.
- apply up_le in H.
- rewrite /upi.
- lia.
-Qed.
-
-Lemma int_of_Z_abs x : `|int_of_Z x| = Z.abs_nat x.
-Proof.
- rewrite /int_of_Z /absz /Z.abs_nat.
- destruct x; trivial.
- lia.
-Qed.
-
-Lemma upi_add2' (n1 n2 : R) :
- (upi (n1 + n2) = (upi n1 + upi n2))%R \/
- (upi (n1 + n2) = (upi n1 + upi n2) -1)%R.
-Proof.
- rewrite /upi.
- destruct (up_add2' n1 n2).
- - left.
- by rewrite H raddfD /=.
- - right.
- rewrite H.
- rewrite -Z.add_opp_r !raddfD /=.
- lra.
-Qed.
-
-Lemma upi_add2 (n1 n2 : R) :
- `|upi (n1 + n2) - (upi n1 + upi n2)%R| <= 1.
-Proof.
- rewrite /upi.
- rewrite (_:
- int_of_Z (up (n1 + n2)) - (int_of_Z (up n1) + int_of_Z (up n2))%R =
- int_of_Z (Z.sub (up (n1 + n2)) (Z.add (up n1) (up n2)))%Z).
- - rewrite int_of_Z_abs.
- apply up_add2.
- - rewrite -Z.add_opp_r !raddfD /= raddfN /= raddfD /=.
- lra.
-Qed.
-
-Lemma ran_round_add2 (n1 n2 cutoff : R) :
- ((0 : R) < cutoff)%O ->
- (cutoff < (1 : R))%O ->
- let sum := ran_round n1 cutoff + ran_round n2 cutoff in
- `|ran_round (n1 + n2)%R cutoff - sum| <= 1.
-Proof.
- move=> cutoff_big cutoff_small.
- rewrite /ran_round.
- case: Order.TotalTheory.ltP=>lt1 ; case: Order.TotalTheory.ltP => lt2 ; case: Order.TotalTheory.ltP => lt3
- ; try (destruct (upi_add2' n1 n2); rewrite H; try lia; rewrite H raddfD /= in lt1; lra).
-Qed.
-
-Lemma nearest_round_add2 (n1 n2 : R) :
- let sum := nearest_round n1 + nearest_round n2 in
- `|nearest_round (n1 + n2) - sum| <= 1.
-Proof.
- rewrite /nearest_round.
- apply ran_round_add2; lra.
-Qed.
-
-Lemma nearest_round_int_add2 (n1 n2 d : int) :
- d <> 0 ->
- let sum := nearest_round_int n1 d + nearest_round_int n2 d in
- `|nearest_round_int (n1 + n2) d - sum| <= 1.
-Proof.
- move=> dn0.
- rewrite /= /nearest_round_int.
- rewrite (_:((n1 + n2)%:~R / d%:~R)%R = ((n1%:~R / d%:~R) + (n2%:~R / d%:~R))%R).
- - apply nearest_round_add2.
- - ring.
-Qed.
-
-Lemma upi_bound (r : R) :
- let diff := (upi r)%:~R - r in
- Rlt 0%R diff /\ Rle diff 1%R.
-Proof.
- rewrite /upi -IZRE.
- apply for_base_fp.
-Qed.
-
-Lemma upi_bound_O (r : R) :
- let diff := (upi r)%:~R - r in
- ((0 : R) < diff)%O /\ (diff <= 1%R)%O.
-Proof.
- destruct (upi_bound r).
- move /RltP in H.
- by move /RleP in H0.
-Qed.
-
-Lemma int_to_R_lt :
- {mono (intmul (1 : R)) : x y / (x < y)%O}.
-Proof.
- move=> x y.
- apply ltr_int.
-Qed.
-
-Lemma int_to_R_le :
- {mono (intr: int -> R) : x y / (x <= y)%O}.
-Proof.
- move => x y.
- apply ler_int.
-Qed.
-
-Lemma upi_nat_mul_bound_R (r : R) (n : nat) :
- (((upi r)%:~R *+n - r*+n) <= n%:R)%O.
-Proof.
- destruct (upi_bound_O r).
- assert ((0 : R) <= n%:R)%O.
- {
- by rewrite (ler_int _ 0 n).
- }
- apply (ler_wpM2r H1) in H0.
- rewrite mul1r mulrBl in H0.
- by replace ((upi r)%:~R *+ n - r *+ n ) with
- ((upi r)%:~R * n%:R - r * n%:R) by ring.
-Qed.
-
-Lemma Rabs_left_O (r : R) :
- (r <= 0)%O -> Rabs r = -r.
-Proof.
- intros.
- move /RleP in H.
- rewrite /opp /=.
- rewrite /zero /= in H.
- case: (boolP (r == 0)); intros.
- - rewrite (eqP p) Rabs_R0 /zero /=.
- coq_lra.
- - move /eqP in i.
- apply Rabs_left.
- rewrite /zero /= in i.
- coq_lra.
-Qed.
-
-Lemma nearest_round_bound_O (r : R) :
- let diff := (nearest_round r)%:~R - r in
- (Rabs diff <= 1/2)%O.
-Proof.
- rewrite /nearest_round /ran_round.
- destruct (upi_bound_O r).
- case: (boolP (Order.lt _ _)); intros.
- - rewrite Rabs_right.
- + by apply Order.POrderTheory.ltW.
- + apply Rle_ge.
- left.
- apply /RltP.
- by rewrite /zero /= in H.
- - rewrite Order.TotalTheory.ltNge in i.
- apply negbNE in i.
- rewrite Rabs_left_O; lra.
-Qed.
-
-
-Lemma nearest_round_bound_O' (r : R) :
- let diff := (nearest_round r)%:~R - r in
- Rabs diff <> 1/2 ->
- (Rabs diff < 1/2)%O.
-Proof.
- generalize (nearest_round_bound_O r); intros.
- rewrite Order.POrderTheory.lt_neqAle.
- unfold diff in *.
- move /eqP in H0.
- by apply /andP.
-Qed.
-
-Lemma Rabs_n (r : R) (n : nat) :
- (Rabs r) *+n = Rabs (r *+ n).
-Proof.
- rewrite -mulr_natr -(mulr_natr r n) RnormM.
- rewrite /mul /=.
- f_equal.
- rewrite Rabs_right //.
- apply Rle_ge.
- apply /RleP.
- replace (IZR Z0) with (0%:R : R).
- - rewrite ler_nat.
- lia.
- - by rewrite mulr0n /zero /=.
- Qed.
-
-Lemma nearest_round_nat_mul_bound_R (r : R) (n : nat) :
- (Rabs (add ((nearest_round r)%:~R *+n) (opp r*+n)) <= (1/2)*+n)%O.
-Proof.
- generalize (nearest_round_bound_O r); intros.
- apply (ler_wMn2r (R := R_numDomainType) n) in H.
- apply /RleP.
- move /RleP in H.
- eapply Rle_trans; [|apply H].
- right.
- rewrite Rabs_n.
- f_equal.
- ring.
-Qed.
-
-Lemma nearest_round_nat_mul_bound_R'_S (r : R) (n : nat) :
- Rabs ((nearest_round r)%:~R - r) <> 1/2 ->
- (Rabs (add ((nearest_round r)%:~R *+n.+1) (opp r*+n.+1)) < (1/2)*+n.+1)%O.
-Proof.
- intros.
- generalize (nearest_round_bound_O' r H); intros.
- apply (ltr_wMn2r (R := R_numDomainType) n.+1) in H0.
- apply /RltP.
- assert (0 < n.+1) by lia.
- rewrite H1 in H0.
- move /RltP in H0.
- eapply Rle_lt_trans; [|apply H0].
- right.
- rewrite Rabs_n.
- f_equal.
- ring.
-Qed.
-
-Lemma upi_nat_mul_bound_R0_lt (r : R) (n : nat) :
- ((0 : R) < ((upi r)%:~R *+n.+1 - r*+n.+1))%O.
-Proof.
- destruct (upi_bound_O r).
- assert (0 < n.+1) by lia.
- apply (ltr_wpMn2r H1 (R := R_numDomainType)) in H.
- rewrite mul0rn in H.
- by replace ((upi r)%:~R *+ n.+1 - r *+ n.+1 ) with
- (((upi r)%:~R - r) *+ n.+1) by ring.
-Qed.
-
-Lemma upi_nat_mul_bound_R0_lt_alt (r : R) (n : nat) :
- (r *+ n.+1 < (upi r)%:~R *+n.+1)%O.
-Proof.
- generalize (upi_nat_mul_bound_R0_lt r n); intros.
- lra.
-Qed.
-
-Lemma upi_nat_mul_bound_R0_lt_alt_alt (r : R) (n : nat) :
- (r *+ n.+1 < ((upi r)*+n.+1)%:~R)%O.
-Proof.
- generalize (upi_nat_mul_bound_R0_lt_alt r n); intros.
- by replace ((upi r *+ n.+1)%:~R)%R with ((((upi r)%:~R) : R) *+ n.+1) by ring.
-Qed.
-
-Lemma up_lt_le (r : R) (i : Z) :
- (Rlt r (IZR i)) ->
- Z.le (up r) i.
-Proof.
- intros.
- destruct (archimed r); intros.
- assert (Rle (IZR (up r) - 1) r) by coq_lra.
- assert (IZR (up r) - 1 = IZR ((up r) - 1)).
- {
- rewrite /opp /add /one /=.
- rewrite minus_IZR.
- coq_lra.
- }
- assert (~(Z.lt i (up r))).
- {
- intros ?.
- assert (Z.le i ((up r) - 1)) by lia.
- assert (Rle (IZR i) (IZR ((up r) -1))).
- {
- by apply IZR_le.
- }
- rewrite -H3 in H6.
- assert (Rle (IZR i) r).
- {
- eapply Rle_trans.
- apply H6.
- apply H2.
- }
- coq_lra.
- }
- lia.
-Qed.
-
-Lemma upi_lt_le (r : R) (i : int) :
- (r < i%:~R)%O ->
- (upi r <= i)%O.
-Proof.
- intros.
- generalize (up_lt_le r (Z_of_int i)); intros.
- rewrite -IZREb in H0.
- move /RltP in H.
- specialize (H0 H).
- rewrite /upi.
- lia.
-Qed.
-
-Lemma upi_nat_mul_bound_R0_le (r : R) (n : nat) :
- ((0 : R) <= ((upi r)%:~R *+n - r*+n))%O.
-Proof.
- destruct n.
- - lra.
- - apply Order.POrderTheory.ltW.
- apply upi_nat_mul_bound_R0_lt.
-Qed.
-
-Lemma upi_nat_mul_le (r : R) (n : nat) :
- (upi (r *+ n.+1) <= upi(r)*+n.+1)%O.
-Proof.
- apply upi_lt_le.
- apply upi_nat_mul_bound_R0_lt_alt_alt.
-Qed.
-
-Definition upiR (r: R) : R := (upi r) %:~R.
-
-Lemma eq_rel2 {A B:Type} {RR:A->B->bool} {a: A} {b:B} {c:A} {d:B} :
- RR a b ->
- a = c ->
- b = d ->
- RR c d.
-Proof.
- congruence.
-Qed.
-
-Lemma upi_nat_mul' (r : R) (n : nat) :
- (upi(r)*+n.+1 - upi (r *+ n.+1) < n.+1%:R)%O.
-Proof.
- destruct (upi_bound_O (r *+ n.+1)).
- generalize (upi_nat_mul_bound_R r n.+1); intros.
- assert ((upiR r) *+ n.+1 - (upiR (r*+n.+1)) < (n.+1%:R))%O.
- {
- replace ((upiR r) *+ n.+1 - upiR (r *+ n.+1)) with
- ( (upiR r) *+ n.+1 - r*+n.+1 - (upiR (r *+ n.+1) - r*+n.+1)) by ring.
- eapply Order.POrderTheory.lt_le_trans; [| apply H1].
- rewrite /upiR.
-
- have: (forall (a b : R), ((0 : R) < b)%O -> (a - b < a)%O) by (intros; lra).
- by apply.
- }
- rewrite /upiR in H2.
- rewrite -int_to_R_lt.
- apply (eq_rel2 H2); ring.
-Qed.
-
-Lemma upi_nat_mul'' (r : R) (n : nat) :
- ((0:int) <= upi(r)*+n.+1 - upi (r *+ n.+1) < n.+1%:R)%O.
-Proof.
- apply /andP.
- split.
- - generalize (upi_nat_mul_le r n); intros.
- lia.
- - apply upi_nat_mul'.
-Qed.
-
-Lemma upi_nat_mul (r : R) (n : nat) :
- (upi(r)*+n - upi (r *+ n) < n%:R)%O.
-Proof.
- destruct n.
- - by rewrite !mulr0n upi0 add0r.
- - apply upi_nat_mul'.
-Qed.
-
-Lemma upi_nat_mul_pos (r : R) (n : nat) :
- ((0 : int) <= upi(r)*+n.+1 - upi (r *+ n.+1))%O.
-Proof.
- generalize (upi_nat_mul_le r n); intros.
- lia.
-Qed.
-
-Lemma upi_nat_mul_abs_S (r : R) (n : nat) :
- `|upi (r *+ n.+1) - upi(r)*+n.+1| < n.+1.
-Proof.
- rewrite distnC.
- assert (Posz `|upi r *+ n.+1 - upi (r *+ n.+1)| < Posz n.+1)%O.
- {
- rewrite gez0_abs; [| apply (upi_nat_mul_pos r n)].
- generalize (upi_nat_mul' r n); intros.
- lia.
- }
- lia.
-Qed.
-
-Lemma upi_nat_mul_abs (r : R) (n : nat) :
- `|upi (r *+ n) - upi(r)*+n| <= n.+1.
-Proof.
- destruct n.
- - by rewrite !mulr0n upi0 addr0.
- - generalize (upi_nat_mul_abs_S r n).
- lia.
-Qed.
-
-Lemma Rabs_opp_sym (x y : R) :
- Rabs (add x (opp y)) = Rabs (add y (opp x)).
-Proof.
- generalize (Rabs_minus_sym x y); intros.
- rewrite /add /opp /=.
- by rewrite /Rminus in H.
-Qed.
-
-Lemma Rplus_add (x y : R) :
- Rplus x y = add x y.
-Proof.
- by rewrite /add /=.
-Qed.
-
-Lemma Ropp_opp(x : R) :
- Ropp x = opp x.
-Proof.
- by rewrite /opp /=.
-Qed.
-
-Lemma Rmult_mul (x y : R) :
- Rmult x y = mul x y.
-Proof.
- by rewrite /mul /=.
-Qed.
-
-Lemma Rabs_absz (j : int) :
- Rabs (j %:~R) = (`|j|%:~R).
-Proof.
- rewrite /absz.
- destruct j.
- - rewrite Rabs_right //.
- apply Rle_ge.
- apply /RleP.
- replace (IZR Z0) with ((0 %:~R):R).
- + rewrite int_to_R_le; lia.
- + by rewrite /zero /=.
- - rewrite NegzE Rabs_left.
- + rewrite Ropp_opp.
- by rewrite mulrNz opprK.
- + apply /RltP.
- replace (IZR Z0) with ((0 %:~R):R).
- * rewrite int_to_R_lt; lia.
- * by rewrite /zero /=.
-Qed.
-
-Lemma half_lt (j : int) (n : nat) :
- (Posz (n/2)%N < j)%O ->
- (Posz n < 2 * j)%O.
-Proof.
- intros.
- generalize (Nat.div_mod_eq n 2); intros.
- case: (boolP (n mod 2 == 0%N)); intros; try lia.
- generalize (Nat.mod_upper_bound n 2); intros.
- lia.
-Qed.
-
-Lemma half_int_to_R_le (j : int) (n : nat):
- ((j %:~R : R) <= (1/2 : R) *+ n)%O ->
- (j <= (n / 2)%N)%O.
-Proof.
- intros.
- case: (boolP (Order.le _ _)); intros; trivial.
- rewrite -Order.TotalTheory.ltNge in i.
- assert ((n : int) < 2*j)%O.
- {
- by apply half_lt.
- }
- rewrite -(@ltr_int R_numDomainType) in H0.
- assert ( (n%:~R : R) / 2 < j%:~R)%O by lra.
- rewrite -(@ltr_int R_numDomainType) in i.
- assert ( (1 / 2 : R) *+ n < j%:~R)%O.
- {
- assert ((1 / 2 : R) *+ n = (n%:~R : R) / 2) by ring.
- by rewrite H2.
- }
- lra.
-Qed.
-
-Lemma half_int_to_R_le_alt (j : int) (n : nat):
- ((j %:~R : R) <= (1/2 : R) *+ n)%O ->
- (j <= (n ./2)%N)%O.
-Proof.
- intros.
- case: (boolP (Order.le _ _)); intros; trivial.
- assert ((n : int) < 2*j)%O by lia.
- rewrite -(@ltr_int R_numDomainType) in H0.
- assert ( (1 / 2 : R) *+ n < j%:~R)%O.
- {
- assert ((1 / 2 : R) *+ n = (n%:~R : R) / 2) by ring.
- lra.
- }
- generalize (Order.POrderTheory.le_lt_trans H H1).
- by rewrite Order.POrderTheory.ltxx.
-Qed.
-
-Lemma half_int_to_R_lt (j : int) (n : nat):
- ((j %:~R : R) < (1/2 : R) *+ n.+1)%O ->
- (j <= (n./2)%N)%O.
-Proof.
- intros.
- case: (boolP (odd n)); intros.
- - assert (n.+1 = (2 * (n.+1 %/ 2))%N).
- {
- generalize (divn_eq n.+1 2); intros.
- assert (n.+1 %% 2 = 0)%N.
- {
- rewrite modn2 oddS p //.
- }
- rewrite H1 addn0 in H0.
- by rewrite {1}H0 mulnC.
- }
- rewrite H0 in H.
- assert (1 / 2 *+ (2 * (n.+1 %/ 2)) = ((Posz (n.+1%/2))%:~R : R)) by (by field).
- rewrite H1 ltr_int in H.
- lia.
- - apply Order.POrderTheory.ltW in H.
- generalize (half_int_to_R_le_alt _ _ H); intros.
- lia.
-Qed.
-
-Lemma Rabs_absz_half (j : int) (n : nat) :
- (Rabs (j %:~R) <= (1/2 : R) *+ n)%O ->
- (`|j| <= (n/2)%N)%O.
-Proof.
- intros.
- rewrite Rabs_absz in H.
- by apply half_int_to_R_le in H.
-Qed.
-
-Lemma Rabs_absz_half_le (j : int) (n : nat) :
- (Rabs (j %:~R) <= (1/2 : R) *+ n)%O ->
- (`|j| <= (n./2)%N)%O.
-Proof.
- intros.
- rewrite Rabs_absz in H.
- by apply half_int_to_R_le_alt in H.
-Qed.
-
-Lemma Rabs_absz_half_lt (j : int) (n : nat) :
- (Rabs (j %:~R) < (1/2 : R) *+ n.+1)%O ->
- (`|j| <= (n./2)%N)%O.
-Proof.
- intros.
- rewrite Rabs_absz in H.
- by apply half_int_to_R_lt in H.
-Qed.
-
-Lemma nearest_round_mul_abs_nat (r : R) (n : nat) :
- `|nearest_round (r *+ n) - nearest_round(r)*+n| <= (n.+1)/2.
-Proof.
- generalize (nearest_round_nat_mul_bound_R r n); intros.
- generalize (nearest_round_bound_O (r *+ n)); intros.
- rewrite /= Rabs_opp_sym in H0.
- generalize (lerD H H0); intros.
- generalize (ler_normD
- ((nearest_round r)%:~R *+ n + - r *+ n)%R
- (r *+ n - (nearest_round (r *+ n))%:~R)%R ); intros.
- generalize (Order.POrderTheory.le_trans H2 H1); intros.
- rewrite -!addrA (addrA (-r*+n) _ _) in H3.
- rewrite mulNrn addNr add0r in H3.
- rewrite distnC.
- rewrite -mulrSr -rmorphMn -rmorphN -rmorphD /= in H3.
- by apply Rabs_absz_half in H3.
-Qed.
-
-Lemma nearest_round_mul_abs_nat_alt (r : R) (n : nat) :
- `|nearest_round (r *+ n) - nearest_round(r)*+n| <= (n.+1)./2.
-Proof.
- generalize (nearest_round_nat_mul_bound_R r n); intros.
- generalize (nearest_round_bound_O (r *+ n)); intros.
- rewrite /= Rabs_opp_sym in H0.
- generalize (lerD H H0); intros.
- generalize (ler_normD
- ((nearest_round r)%:~R *+ n + - r *+ n)%R
- (r *+ n - (nearest_round (r *+ n))%:~R)%R ); intros.
- generalize (Order.POrderTheory.le_trans H2 H1); intros.
- rewrite -!addrA (addrA (-r*+n) _ _) in H3.
- rewrite mulNrn addNr add0r in H3.
- rewrite distnC.
- rewrite -mulrSr -rmorphMn -rmorphN -rmorphD /= in H3.
- by apply Rabs_absz_half_le in H3.
-Qed.
-
-Lemma nearest_round_0 :
- nearest_round 0 = 0.
-Proof.
- rewrite /nearest_round /ran_round upi0 oppr0 addr0.
- replace (((1%:~R : R )< 1 / 2)%O) with false by lra.
- lia.
-Qed.
-
-Lemma nearest_round_int_val (n : int) :
- nearest_round (n %:~R) = n.
-Proof.
- rewrite /nearest_round /ran_round.
- generalize (upi_intl n 0); intros.
- rewrite Rplus_add addr0 upi0 in H.
- rewrite H intrD addrC addrA addNr add0r.
- replace ((1%:~R : R) < 1 / 2)%O with false by lra.
- lia.
-Qed.
-
-Lemma nearest_round_half (r : R) :
- (upi r)%:~R - r = 1 / 2 <->
- Rabs ((nearest_round r)%:~R - r) = 1 / 2.
-Proof.
- rewrite /nearest_round /ran_round.
- split; intros.
- - rewrite H.
- assert ((((1 / 2):R) < 1 / 2)%O = false) by lra.
- rewrite H0 rmorphD /= /Rminus Rplus_add Ropp_opp.
- rewrite -addrA addrC -addrA (addrC (- r) _) H.
- replace ( ((-1)%:~R : R) + 1 / 2) with (- (1 / 2):R) by lra.
- rewrite -Ropp_opp Rabs_Ropp Rabs_right // mul1r.
- apply Rle_ge; left.
- apply /RltP.
- rewrite invr_gt0.
- lra.
- - case: (boolP ((upi r)%:~R - r < 1 / 2)%O); intros.
- + rewrite p in H.
- rewrite /Rminus Ropp_opp Rplus_add in H.
- rewrite Rabs_right // in H.
- left.
- apply upi_bound.
- + assert (((upi r)%:~R - r < 1 / 2)%O = false) by lra.
- rewrite H0 in H.
- rewrite rmorphD rmorphN /= in H.
- destruct (upi_bound_O r).
- rewrite /Rminus Ropp_opp Rplus_add in H.
- case: (boolP (Order.le (0 : R) ((upi r)%:~R - 1%:~R - r)%R )); intros.
- * move /RleP in p.
- rewrite Rabs_right in H.
- -- move /RleP in p.
- lra.
- -- by apply Rle_ge.
- * rewrite -Order.TotalTheory.ltNge in i0.
- rewrite Rabs_left in H.
- -- rewrite Ropp_opp in H.
- lra.
- -- by apply /RltP.
-Qed.
-
-Lemma nearest_round_not_half (r : R) :
- (upi r)%:~R - r <> 1 / 2 ->
- (Rabs ((nearest_round r)%:~R - r) < 1 / 2)%O.
-Proof.
- intros.
- rewrite Order.POrderTheory.lt_def.
- apply /andP.
- split.
- - apply /eqP.
- intros ?.
- symmetry in H0.
- by rewrite -nearest_round_half in H0.
- - apply nearest_round_bound_O.
-Qed.
-
-Lemma nearest_round_not_half' (r : R) :
- (Rabs ((nearest_round r)%:~R - r) <> 1 / 2) ->
- (Rabs ((nearest_round r)%:~R - r) < 1 / 2)%O.
-Proof.
- intros.
- rewrite Order.POrderTheory.lt_def.
- apply /andP.
- split.
- - lra.
- - apply nearest_round_bound_O.
-Qed.
-
-Lemma nearest_round_half_val (r : R) :
- Rabs ((nearest_round r)%:~R - r) = 1 / 2 ->
- (nearest_round r)%:~R - r = -(1 / 2).
-Proof.
- intros.
- rewrite -nearest_round_half in H.
- rewrite /nearest_round /ran_round.
- rewrite H.
- case: (boolP (Order.lt _ _)); lra.
-Qed.
-
-Lemma nearest_round_mul_abs_nat_half (r : R) (n : nat) :
- Rabs ((nearest_round r)%:~R - r) = 1 / 2 ->
- Rabs ((nearest_round (r *+ n.+1))%:~R - r *+ n.+1)%R = 1 / 2 ->
- (`|nearest_round (r *+ n.+1) - nearest_round(r)*+n.+1| <= n./2)%O.
-Proof.
- intros.
- rewrite /Rminus Ropp_opp Rplus_add in H.
- apply nearest_round_half_val in H0.
- apply nearest_round_half_val in H.
- apply (f_equal (fun z => z *+ n.+1)) in H.
- assert (((nearest_round (r *+ n.+1))%:~R:R) - (nearest_round r)%:~R *+ n.+1 = -(- (1 / 2) *+ n.+1) - (1 / 2)).
- {
- rewrite -H -H0.
- ring.
- }
- replace ( - (1 / 2) *- n.+1 - 1 / 2) with (((1 / 2) *+ n):R) in H1 by ring.
- rewrite -rmorphMn -rmorphN -rmorphD /= in H1.
- apply (f_equal (fun z => Rabs z)) in H1.
- rewrite (Rabs_right (1 / 2 *+ n)) in H1.
- - move /eqP in H1.
- rewrite Order.POrderTheory.eq_le in H1.
- move /andP in H1.
- destruct H1.
- by apply Rabs_absz_half_le in H1.
- - apply Rle_ge.
- rewrite -mulrnAl.
- apply /RleP.
- apply mulr_ge0; [|lra].
- apply mulrn_wge0; lra.
- Qed.
-
-Lemma nearest_round_mul_abs_nat_not_half1 (r : R) (n : nat) :
- Rabs ((nearest_round r)%:~R - r) <> 1 / 2 ->
- `|nearest_round (r *+ n.+1) - nearest_round(r)*+n.+1| <= n.+1./2.
-Proof.
- intros.
- generalize (nearest_round_nat_mul_bound_R'_S r n H); intros.
- generalize (nearest_round_bound_O (r *+ n.+1)); intros.
- simpl in H1.
- rewrite Rabs_opp_sym in H1.
- generalize (ltr_leD H0 H1); intros.
- generalize (ler_normD
- ((nearest_round r)%:~R *+ n.+1 + - r *+ n.+1)%R
- (r *+ n.+1 - (nearest_round (r *+ n.+1))%:~R)%R ); intros.
- generalize (Order.POrderTheory.le_lt_trans H3 H2); intros.
- rewrite -addrA (addrA _ (r *+ n.+1) _) in H4.
- replace (- r *+ n.+1 + r *+ n.+1) with (0:R) in H4 by ring.
- rewrite add0r -mulrSr in H4.
- rewrite distnC.
- rewrite -rmorphMn -rmorphN -rmorphD /= in H4.
- by apply Rabs_absz_half_lt in H4.
-Qed.
-
-Lemma nearest_round_mul_abs_nat_not_half2 (r : R) (n : nat) :
- Rabs ((nearest_round (r *+ n.+1))%:~R - r *+ n.+1)%R <> 1 / 2 ->
- `|nearest_round (r *+ n.+1) - nearest_round(r)*+n.+1| <= n.+1./2.
-Proof.
- intros.
- generalize (nearest_round_nat_mul_bound_R r n.+1); intros.
- generalize (nearest_round_bound_O' (r *+ n.+1) H); intros.
- rewrite mulNrn Rabs_opp_sym in H0.
- generalize (ltr_leD H1 H0); intros.
- generalize (ler_normD
- ((nearest_round (r *+ n.+1))%:~R - r *+ n.+1)%R
- (r *+ n.+1 - (nearest_round r)%:~R *+ n.+1)%R ); intros.
- generalize (Order.POrderTheory.le_lt_trans H3 H2); intros.
- rewrite -addrA (addrA _ (r *+ n.+1) _) in H4.
- replace (r *- n.+1 + r *+ n.+1) with (0:R) in H4 by ring.
- rewrite add0r (addrC (1 / 2) _) -mulrSr in H4.
- rewrite -rmorphMn -rmorphN -rmorphD /= in H4.
- by apply Rabs_absz_half_lt in H4.
-Qed.
-
-Lemma nearest_round_mul_abs_nat_0 (r : R) :
- `|nearest_round (r *+ 0) - nearest_round(r)*+0| = 0%N.
-Proof.
- by rewrite mulr0n nearest_round_0 mulr0n oppr0 addr0 /=.
-Qed.
-
-Lemma nearest_round_mul_abs_nat' (r : R) (n : nat) :
- `|nearest_round (r *+ n) - nearest_round(r)*+n| <= n./2.
-Proof.
- destruct n.
- - rewrite nearest_round_mul_abs_nat_0.
- lia.
- - case: (boolP (Rabs ((nearest_round r)%:~R - r) == 1 / 2)).
- + case: (boolP (Rabs ((nearest_round (r *+ n.+1))%:~R - r *+ n.+1)%R == 1 / 2)).
- * intros.
- move /eqP in p.
- move /eqP in p0.
- generalize (nearest_round_mul_abs_nat_half r n p0 p).
- lia.
- * intros.
- move /eqP in i.
- by apply nearest_round_mul_abs_nat_not_half2.
- + intros.
- move /eqP in i.
- by apply nearest_round_mul_abs_nat_not_half1.
-Qed.
-
-Lemma nearest_round_mul_abs_nat_opp (r : R) (n : nat) :
- `|nearest_round r *+ n + nearest_round (r *- n)| <=
- n.+1 / 2.
-Proof.
- replace (r *-n) with ((opp r) *+ n) by ring.
- generalize (nearest_round_nat_mul_bound_R r n); intros.
- generalize (nearest_round_bound_O ((opp r) *+ n)); intros.
- simpl in H0.
- replace (opp (natmul (opp r) n)) with (natmul r n) in H0 by ring.
- generalize (lerD H H0); intros.
- generalize (ler_normD
- ((nearest_round r)%:~R *+ n + - r *+ n)%R
- ((nearest_round (- r *+ n))%:~R + r *+ n)%R); intros.
- generalize (Order.POrderTheory.le_trans H2 H1); intros.
- rewrite -addrA (addrC (-r *+ n) _) -addrA in H3.
- replace (r *+ n + (-r) *+n) with (0 : R) in H3 by ring.
- rewrite addr0 -mulrSr in H3.
- rewrite -rmorphMn -rmorphD /= in H3.
- by apply Rabs_absz_half in H3.
-Qed.
-
-Lemma nearest_round_mul_abs_nat_opp_not_half1 (r : R) (n : nat) :
- Rabs ((nearest_round r)%:~R - r) <> 1 / 2 ->
- `|nearest_round r *+ n.+1 + nearest_round (r *- n.+1)| <= n.+1 ./2.
-Proof.
- intros.
- generalize (nearest_round_nat_mul_bound_R'_S r n H); intros.
- generalize (nearest_round_bound_O ((opp r) *+ n.+1)); intros.
- rewrite mulNrn in H0.
- rewrite mulNrn opprK /= in H1.
- generalize (ltr_leD H0 H1); intros.
- generalize (ler_normD
- ((nearest_round r)%:~R *+ n.+1 - r *+ n.+1)%R
- ((nearest_round (r *- n.+1))%:~R + r *+ n.+1)%R); intros.
- generalize (Order.POrderTheory.le_lt_trans H3 H2); intros.
- rewrite -addrA (addrC _(r *+ n.+1)) (addrA (r*-n.+1) _ _) addNr add0r in H4.
- rewrite -mulrSr -rmorphMn -rmorphD /= in H4.
- by apply Rabs_absz_half_lt in H4.
-Qed.
-
-Lemma nearest_round_mul_abs_nat_opp_not_half2 (r : R) (n : nat) :
- Rabs ((nearest_round (- r *+ n.+1))%:~R - - r *+ n.+1)%R <> 1 / 2 ->
-
- `|nearest_round r *+ n.+1 + nearest_round (r *- n.+1)| <= n.+1 ./2.
-Proof.
- intros.
- generalize (nearest_round_nat_mul_bound_R r n.+1); intros.
- generalize (nearest_round_bound_O' ((opp r) *+ n.+1) H); intros.
- rewrite mulNrn in H0.
- rewrite mulNrn opprK in H1.
- generalize (ltr_leD H1 H0); intros.
- generalize (ler_normD
- ((nearest_round (r *- n.+1))%:~R + r *+ n.+1)%R
- ((nearest_round r)%:~R *+ n.+1 - r *+ n.+1)%R); intros.
- generalize (Order.POrderTheory.le_lt_trans H3 H2); intros.
- rewrite -addrA in H4.
- rewrite (addrC ((nearest_round r)%:~R *+ n.+1) _) in H4.
- rewrite (addrA (r*+n.+1) _ _) addrN add0r in H4.
- rewrite -rmorphMn -rmorphD -mulrS in H4.
- apply Rabs_absz_half_lt in H4.
- lia.
-Qed.
-
-(*
- r = IZR (Int_part r) + frac_part r.
- 0 <= frac_part r < 1.
-*)
-
-Lemma up_opp_Z (r : R) :
- IZR (up r) = Rplus r 1 ->
- up (-r) = Zplus (- (up r)) (Zpos 2).
-Proof.
- intros.
- symmetry.
- apply tech_up.
- - rewrite plus_IZR opp_IZR H.
- coq_lra.
- - rewrite plus_IZR opp_IZR H.
- coq_lra.
-Qed.
-
-Lemma up_opp_nZ (r : R) :
- IZR (up r) <> Rplus r 1 ->
- up (-r) = Zplus (- (up r)) (Zpos 1).
-Proof.
- intros.
- symmetry.
- destruct (archimed r).
- apply tech_up.
- - rewrite plus_IZR opp_IZR.
- coq_lra.
- - rewrite plus_IZR opp_IZR.
- coq_lra.
-Qed.
-
-Lemma upi_opp_int (r : R) :
- (upi r)%:~R = r + 1 ->
- upi (opp r) = - (upi r) + 2.
-Proof.
- intros.
- rewrite /upi up_opp_Z.
- - lia.
- - by rewrite -IZRE /add /one /= in H.
-Qed.
-
-Lemma upi_opp_nint (r : R) :
- (upi r)%:~R <> r + 1 ->
- upi (opp r) = - (upi r) + 1.
-Proof.
- intros.
- rewrite /upi up_opp_nZ.
- - lia.
- - by rewrite -IZRE /add /one /= in H.
-Qed.
-
-Lemma nearest_round_opp_not_half (r : R) :
- (upi r)%:~R - r <> 1/2 ->
- nearest_round (opp r) = - nearest_round r.
-Proof.
- rewrite /nearest_round /ran_round.
- case: (boolP ((upi r)%:~R == r + 1)); intros.
- - move /eqP in p.
- rewrite (upi_opp_int r p) opprK rmorphD rmorphN /= p.
- case: (boolP (Order.lt _ _)); intros.
- + by assert false by lra.
- + case: (boolP (Order.lt _ _)); intros.
- * by assert false by lra.
- * ring.
- - move /eqP in i.
- rewrite (upi_opp_nint r i) opprK rmorphD rmorphN /=.
- case: (boolP (Order.lt _ _)); intros.
- + case: (boolP (Order.lt _ _)); intros.
- * by assert false by lra.
- * ring.
- + case: (boolP (Order.lt _ _)); intros.
- * ring.
- * by assert false by lra.
-Qed.
-
-Lemma nearest_round_opp_half (r : R) :
- (upi r)%:~R - r = 1/2 ->
- nearest_round (-r) = - nearest_round r - 1.
-Proof.
- rewrite /nearest_round /ran_round.
- intros.
- assert( (upi r)%:~R <> r + 1 ) by lra.
- rewrite (upi_opp_nint r H0) rmorphD rmorphN /= opprK.
- case : (boolP (Order.lt _ _)); intros.
- - by assert false by lra.
- - case : (boolP (Order.lt _ _)); intros.
- + by assert false by lra.
- + ring.
-Qed.
-
-Lemma upi_mul_abs (r : R) (n : int) :
- `|upi (r *~ n) - (upi r) *~n| <= `|n|+1.
-Proof.
- destruct n; simpl.
- - rewrite -!pmulrn.
- generalize (upi_nat_mul_abs r n); intros.
- lia.
- - rewrite distnC /intmul /=.
- generalize (upi_nat_mul_abs_S r n); intros.
- rewrite distnC in H.
- case: (boolP ((upi (r *+ n.+1))%:~R == (r*+n.+1)+1)); intros.
- + rewrite (upi_opp_int _ (eqP p)) opprD opprK addrC distnC.
- rewrite opprB (addrC 2 _) addrA.
- generalize (absz_triang (upi r *+ n.+1 - upi (r *+ n.+1)) 2); intros.
- lia.
- + move /eqP in i.
- rewrite (upi_opp_nint _ i) opprD opprK addrC distnC.
- rewrite opprB (addrC 1 _) addrA.
- generalize (absz_triang (upi r *+ n.+1 - upi (r *+ n.+1)) 1); intros.
- lia.
-Qed.
-
-Lemma upi_mul_abs_alt (r : R) (n : int) :
- `|(upi (mul r n%:~R) - (upi r) * n)%Z| <= `|n|+1.
-Proof.
- generalize (upi_mul_abs r n); intros.
- replace (upi (mul r n%:~R) - upi r * n)%R with
- (upi (r *~ n) - upi r *~ n); trivial.
- f_equal.
- - f_equal; lra.
- - f_equal; f_equal; lia.
-Qed.
-
-Lemma nearest_round_mul_abs (r : R) (n : int) :
- `|nearest_round (r *~ n) - nearest_round(r)*~n| <= (`|n|+1)/2.
-Proof.
- destruct n.
- - rewrite -!pmulrn.
- generalize (nearest_round_mul_abs_nat r n); intros.
- replace (`|n| + 1)%N with (n.+1) by lia.
- lia.
- - rewrite distnC /intmul NegzE.
- replace (`|-(Posz n.+1)|+1)%N with (n.+2) by lia.
- rewrite -opprD abszN.
- apply nearest_round_mul_abs_nat_opp.
-Qed.
-
-Lemma nearest_round_int_mul (n1 n2 d : int) :
- d <> 0 ->
- `|(nearest_round_int (n1 * n2) d - nearest_round_int n1 d * n2)%R| <= (`|n2| + 1)/2.
-Proof.
- rewrite /nearest_round_int intrM.
- rewrite (_: (n1%:~R * n2%:~R / d%:~R) = ((n1%:~R / d%:~R)%R * n2%:~R )); last by lra.
- move: {n1} (((n1%:~R / d%:~R)%R)) => n1 _ {d}.
- replace (n1 * n2%:~R) with (n1*~n2) by ring.
- replace (nearest_round n1 * n2) with (nearest_round n1 *~ n2).
- apply nearest_round_mul_abs.
- destruct n2; ring.
-Qed.
-
-Lemma nearest_round_int_add2_ex (n1 n2 d : int) :
- d <> 0 ->
- let sum := nearest_round_int n1 d + nearest_round_int n2 d in
- { n3 : int |
- nearest_round_int (n1 + n2) d = sum + n3 /\
- `|n3| <= 1}.
-Proof.
- intros.
- exists (nearest_round_int (n1 + n2) d - sum).
- split; try lia.
- by apply nearest_round_int_add2.
-Qed.
-
-Lemma nearest_round_int_mul_ex (n1 n2 d : int) :
- d <> 0 ->
- { n3 : int |
- nearest_round_int (n1 * n2) d = nearest_round_int n1 d * n2 + n3 /\
- `|n3| <= (`|n2|+1)/2}.
-Proof.
- intros.
- exists (nearest_round_int (n1 * n2) d - nearest_round_int n1 d * n2).
- split; try lia.
- by apply nearest_round_int_mul.
-Qed.
-
-Definition div_round (a : {poly int}) (d : int) : {poly int} :=
- map_poly (fun c => nearest_round_int c d) a.
-
-Lemma div_round0 (den : int) :
- div_round (0 : {poly int}) den = 0.
-Proof.
- apply map_poly0.
-Qed.
-
-Lemma nth_map_default:
- forall [T1 : Type] (x1 : T1) [T2 : Type] (x2 : T2) (f : T1 -> T2) [n : nat] [s : seq T1],
- f x1 = x2 ->
- nth x2 [seq f i | i <- s] n = f (nth x1 s n).
-Proof.
- intros.
- case/orP: (leqVgt (size s) n) => ineq.
- - by rewrite !nth_default // size_map.
- - by rewrite (nth_map x1).
-Qed.
-
-Lemma div_round_mul_add (a b : {poly int}) (d : int) :
- d <> 0 ->
- div_round (a + d *: b) d = div_round a d + b.
-Proof.
- intros.
- rewrite /div_round !map_polyE -polyP => i.
- rewrite coefD !coef_Poly.
- rewrite !(nth_map_default 0 0); try by rewrite nearest_round_int0.
- by rewrite coefD -nearest_round_int_mul_add_r // coefZ.
-Qed.
-
-Lemma div_round_muln_add (a b : {poly int}) (d : nat) :
- d <> 0%N ->
- div_round (a + b *+ d) d = div_round a d + b.
-Proof.
- intros.
- rewrite -div_round_mul_add; try lia.
- by rewrite -scaler_nat /GRing.scale /= !scale_polyE natz.
-Qed.
-
-Lemma div_round_muln (b : {poly int}) (d : nat) :
- d <> 0%N ->
- div_round (b *+ d) d = b.
-Proof.
- intros.
- generalize (div_round_muln_add 0 b d H); intros.
- by rewrite add0r div_round0 add0r in H0.
-Qed.
-
-Lemma div_round_muln_add_l (a b : {poly int}) (d : nat) :
- d <> 0%N ->
- div_round (b *+ d + a) d = b + div_round a d.
-Proof.
- intros.
- rewrite addrC div_round_muln_add //.
- by rewrite addrC.
-Qed.
-
-Lemma div_round_add2_ex (a b : {poly int}) (d : int) :
- d <> 0 ->
- { c : {poly int} |
- div_round (a + b) d = div_round a d + div_round b d + c /\
- icoef_maxnorm c <= 1}.
-Proof.
- exists (div_round (a + b) d - (div_round a d + div_round b d)).
- split.
- - ring.
- - rewrite /icoef_maxnorm /div_round.
- apply /bigmax_leqP => i _.
- rewrite !(coefD,coefN).
- rewrite !coef_map_id0; try by rewrite nearest_round_int0.
- rewrite !coefD.
- by apply nearest_round_int_add2.
-Qed.
-
-Definition div_round_add2_perturb (a b : {poly int}) (d : int) (dn0 : d <> 0) : {poly int}
- := sval (div_round_add2_ex a b d dn0).
-
-Lemma div_round_add2_eq (a b : {poly int}) (d : int) (dn0 : d <> 0)
- : div_round (a + b) d = div_round a d + div_round b d +
- div_round_add2_perturb a b d dn0.
-Proof.
- rewrite /div_round_add2_perturb.
- case: (div_round_add2_perturb a b d dn0).
- case: div_round_add2_ex; intros; simpl.
- tauto.
-Qed.
-
-Lemma div_round_add2_eq_alt (a b : {poly int}) (d : int) (dn0 : d <> 0) :
- div_round a d + div_round b d = div_round (a + b) d - div_round_add2_perturb a b d dn0.
-Proof.
- generalize (div_round_add2_eq a b d dn0); intros.
- rewrite H.
- ring.
-Qed.
-
-Lemma div_round_add2_perturb_small (a b : {poly int}) (d : int) (dn0 : d <> 0)
- : icoef_maxnorm (div_round_add2_perturb a b d dn0) <= 1.
-Proof.
- rewrite /div_round_add2_perturb.
- case: (div_round_add2_perturb a b d dn0).
- case: div_round_add2_ex; intros; simpl.
- tauto.
-Qed.
-
-Definition q_reduce (q : nat) (p : {poly int}) : {poly 'Z_q} :=
- map_poly (fun c => c%:~R) p.
-
-Lemma q_reduce_is_rmorphism (q : nat) :
- rmorphism (q_reduce q).
-Proof.
- apply map_poly_is_rmorphism.
-Qed.
-
-Canonical q_reduce_rmorphism (q : nat) := RMorphism (q_reduce_is_rmorphism q).
-
-Definition public_key {q : nat} (e s : {poly int}) (a : {poly 'Z_q}) : {poly {poly 'Z_q}} :=
- Poly [:: (- a * (q_reduce q s) + (q_reduce q e)); a].
-
-Definition FHE_encrypt {q : nat} (p : {poly 'Z_q}) (v e0 e1 : {poly int}) (evkey : {poly {poly 'Z_q}}) :=
- Poly [:: (p + q_reduce q e0); q_reduce q e1] + (q_reduce q v) *: evkey.
-
-Definition FHE_decrypt {q : nat} (s : {poly int}) (pp : {poly {poly 'Z_q}}) :=
- pp.[q_reduce q s].
-
-Lemma decrypt_encrypt {q : nat} (e s v e0 e1 : {poly int}) (a p : {poly 'Z_q}) :
- FHE_decrypt s (FHE_encrypt p v e0 e1 (public_key e s a)) =
- p + (q_reduce q e0) + q_reduce q e1 * q_reduce q s + q_reduce q v * q_reduce q e.
-Proof.
- rewrite /FHE_decrypt /FHE_encrypt /public_key.
- rewrite hornerD hornerZ !horner_Poly /= mul0r !add0r.
- ring.
-Qed.
-
-Lemma decrypt_add {q : nat} (P Q : {poly 'Z_q}) (PP QQ : {poly {poly 'Z_q}}) (s : {poly int}) :
- FHE_decrypt s PP = P ->
- FHE_decrypt s QQ = Q ->
- FHE_decrypt s (FHE_add PP QQ) = P + Q.
-Proof.
- rewrite /FHE_decrypt /FHE_add.
- intros.
- by rewrite hornerD H H0.
-Qed.
-
-Lemma decrypt_mult_base {q : nat} (P Q : {poly 'Z_q}) (PP QQ : {poly {poly 'Z_q}}) (s : {poly int}) :
- FHE_decrypt s PP = P ->
- FHE_decrypt s QQ = Q ->
- FHE_decrypt s (FHE_mult_base PP QQ) = P * Q.
-Proof.
- rewrite /FHE_decrypt /FHE_mult_base.
- intros.
- by rewrite hornerM H H0.
-Qed.
-
-Definition key_switch_key {q p : nat} (s s2 e : {poly int}) (a : {poly 'Z_(p*q)}) : {poly {poly 'Z_(p*q)}} :=
- Poly [:: (-a * (q_reduce (p * q) s) + (q_reduce (p * q) e) + q_reduce (p * q) (s2 *+ p)); a].
-
-Definition ev_key {q p : nat} (s e : {poly int}) (a : {poly 'Z_(p*q)}) :=
- key_switch_key s (exp s 2) e a.
-
-Definition linearize {q p : nat} (c0 c1 c2 : {poly 'Z_q})
- (evkey : {poly {poly 'Z_(p*q)}}) :=
- Poly [:: c0; c1] +
- map_poly (fun P => q_reduce q (div_round ((zlift c2) * (zlift P)) (p%:Z)))
- evkey.
-
-(* evkey = [:: (-a * (q_reduce (p * q) s) + (q_reduce (p * q) e) + (q_reduce (p * q) (s^2 *+ p))); a]
- *)
-Ltac notHyp P :=
- match goal with
- | [ _ : P |- _ ] => fail 1
- | _ => idtac
- end.
-
-Ltac extend_as_perturb P :=
- let t := type of P in
- notHyp t; generalize P;
- let x := (fresh "perturb_small") in
- intros x.
-
-Ltac perturb_zlift_facts
- := repeat match goal with
- | [|- context [zlift_add2_perturb ?a ?b ?c]] =>
- extend_as_perturb (zlift_add2_perturb_small a b c)
- end.
-
-Ltac perturb_div_round_facts
- := repeat match goal with
- | [|- context [div_round_add2_perturb ?a ?b ?c ?d]] =>
- extend_as_perturb (div_round_add2_perturb_small a b c d)
- end.
-
-Lemma reduce_prod1 (p q : nat) (a : {poly int}) :
- q_reduce (q) (a *+ p) = (q_reduce q a) *+ p.
-Proof.
- by rewrite rmorphMn.
-Qed.
-
-Lemma modn_mul2 (p q r: nat) :
- p %% q * r = (p * r) %[mod q].
-Proof.
- by rewrite modnMml.
-Qed.
-
-Lemma modn_mul2r (p q r: nat) :
- r * (p %% q) = r * p %[mod q].
-Proof.
- by rewrite modnMmr.
-Qed.
-
-Lemma modn_prod2 (p q n : nat) :
- ((n %% (p * q) * p) %% (p * q))%N = (n %% q * p)%N.
-Proof.
- rewrite modn_mul2 (mulnC p _).
- by rewrite muln_modl.
-Qed.
-
-Lemma modn_prod2r (p q n : nat) :
- ((p * (n %% (p * q))) %% (p * q))%N = (p * (n %% q))%N.
-Proof.
- by rewrite modn_mul2r muln_modr.
-Qed.
-
-Lemma reduce_prod2 (p q : nat) (a : int) :
- 1 < q ->
- 1 < p*q ->
- nat_of_ord ((a %:~R : 'Z_(p*q))*+ p) = ((a%:~R : 'Z_q) * p)%N.
-Proof.
- intros.
- destruct a; rewrite /intmul !Zp_mulrn mul1n /inZp /= !Zp_cast //.
- by apply modn_prod2.
- - rewrite modn_prod2 //.
- f_equal.
- rewrite !modnB; [|lia|lia|lia|lia].
- rewrite modnMl modnn !addn0.
- replace (modn (modn (S n) (muln p q)) q) with
- (modn (modn (S n) q) q); trivial.
- rewrite modn_mod modn_dvdm //.
- apply dvdn_mull.
- by apply dvdnn.
-Qed.
-
-Lemma div2_le (p q : nat) :
- (p <= q/2)%N = (2 * p <= q)%N.
-Proof.
- move: (Nat.divmod_spec q 1 0 1 (le_refl _)) => /=.
- case: (Nat.divmod q 1 0 1) => /= x u.
- rewrite !Nat.add_0_r.
- move=> [-> ubound].
- destruct u; lia.
-Qed.
-
-Lemma liftc_reduce_prod2_nat (p q : nat) (a : nat) :
- 1 < q ->
- 1 < p*q ->
- zliftc ((a %:~R : 'Z_(p*q))*+ p) = (zliftc (a%:~R : 'Z_q)) *+p.
-Proof.
- move=> qbig pqbig.
- rewrite /zliftc.
- have ->: ((a %:~R : 'Z_(p*q))*+ p <= p * q / 2) = (((a%:~R : 'Z_q)) <= q /2).
- {
- rewrite -mulr_natl !div2_le.
- rewrite /intmul !Zp_mulrn mul1n /inZp /= !Zp_cast //.
- rewrite [modn (S 0) (muln p q)]modn_small //.
- rewrite [modn (S 0) q]modn_small // !mul1n.
- rewrite modn_mul2 modn_prod2r.
- rewrite mulnA (mulnC 2%N p) -mulnA.
- rewrite leq_pmul2l //.
- lia.
- }
- case: leqP; rewrite reduce_prod2 //; lia.
-Qed.
-
-Lemma Nat_even_even (x : nat) : Nat.even x = ~~ odd x.
-Proof.
- induction x => //=.
- destruct x; [lia |].
- by rewrite -IHx Nat.even_succ Nat.negb_odd.
-Qed.
-
-Lemma Nat_odd_odd (x : nat) : Nat.odd x = odd x.
-Proof.
- by rewrite -Nat.negb_even Nat_even_even negbK.
-Qed.
-
-Lemma qodd_half (q : nat) :
- odd q ->
- (q = q/2 + q/2 + 1)%N.
-Proof.
- rewrite -!Nat.div2_div => qodd.
- rewrite [LHS]Nat.div2_odd Nat_odd_odd qodd /=.
- lia.
-Qed.
-
-Lemma liftc_neg0 (q : nat) (a : int) :
- 1 < q ->
- (a %:~R : 'Z_q) = 0 ->
- zliftc ((-a)%:~R : 'Z_q) = - zliftc (a %:~R : 'Z_q).
-Proof.
- intros.
- rewrite /zliftc.
- assert (((-a) %:~R : 'Z_q) = 0).
- {
- apply (f_equal (fun z => opp z)) in H0.
- rewrite -rmorphN /= in H0.
- by rewrite H0 oppr0.
- }
- rewrite H0 H1 /=; lia.
-Qed.
-
-Lemma liftc_neg_prop (q : nat) (a : 'Z_q) :
- (q = q/2 + q/2 + 1)%N ->
- q - a <= q/2 = ~~ (a <= q/2).
-Proof.
- lia.
-Qed.
-
-Lemma liftc_neg_prop_alt (q : nat) (a : 'Z_q) :
- 1 < q ->
- (q - a != a)%N ->
- q - a <= q/2 = ~~ (a <= q/2).
-Proof.
- intros.
- generalize (Nat.div_mod_eq q 2); intros.
- assert (2 * (q/2) <= q)%coq_nat by lia.
- destruct a.
- simpl.
- simpl in H0.
- rewrite Zp_cast // in i.
- case: (boolP (q - m <= q/2)); intros.
- - case (boolP (m <= q/2)); intros i0; rewrite i0; lia.
- - case (boolP (m <= q/2)); intros i1; rewrite i1; try tauto.
- assert ((q mod 2)%coq_nat < 2).
- {
- rewrite modulo_modn.
- by rewrite ltn_mod.
- }
- lia.
-Qed.
-
-Lemma Z_q_small {q:nat} (c : 'Z_q) :
- 1 < q ->
- c < q.
-Proof.
- move=> qbig.
- case: c => m i /=.
- by rewrite Zp_cast // in i.
-Qed.
-
-Lemma liftc_neg (q : nat) (a : int) :
- 1 < q ->
- odd q ->
- (* ~~ intdiv.dvdz q (2 * a) *)
- zliftc ((-a)%:~R : 'Z_q) = - zliftc (a %:~R : 'Z_q).
-Proof.
- move=> qbig qodd.
- rewrite /zliftc.
- move: (qodd_half q qodd) => qsum.
- move : (Z_q_small (a%:~R : 'Z_q) qbig) => asmall.
- case: (eqVneq (a %:~R : 'Z_q) 0)=>i; [by apply liftc_neg0|].
- assert (apos: 0 < (a%:~R : 'Z_q)).
- {
- have anneg: 0 <= (a%:~R : 'Z_q) by lia.
- by rewrite Order.NatOrder.ltn_def i anneg.
- }
-
- have ->: (((- a)%:~R : 'Z_q) <= q / 2) = (~~ ((a%:~R : 'Z_q) <= q / 2)).
- {
- rewrite rmorphN /= {1 3}Zp_cast //.
- rewrite modnB; [|lia|lia].
- assert (0 < (a%:~R : 'Z_q) %% q).
- {
- rewrite modn_small //.
- }
- rewrite modnn H mul1n addn0 modn_small //.
- by apply liftc_neg_prop.
- }
- rewrite rmorphN /=.
- case: leqP => /=; intros; rewrite {1 3}Zp_cast // modn_small //; lia.
-Qed.
-
-Lemma liftc_neg_alt (q : nat) (a : int) :
- 1 < q ->
- (q - (a%:~R : 'Z_q) != (a%:~R : 'Z_q))%N ->
- zliftc ((-a)%:~R : 'Z_q) = - zliftc (a %:~R : 'Z_q).
-Proof.
- move=> qbig not_ahalf.
- rewrite /zliftc.
- move : (Z_q_small (a%:~R : 'Z_q) qbig) => asmall.
- case: (eqVneq (a %:~R : 'Z_q) 0)=>i; [by apply liftc_neg0|].
- assert (apos: 0 < (a%:~R : 'Z_q)).
- {
- have anneg: 0 <= (a%:~R : 'Z_q) by lia.
- by rewrite Order.NatOrder.ltn_def i anneg.
- }
-
- have ->: (((- a)%:~R : 'Z_q) <= q / 2) = (~~ ((a%:~R : 'Z_q) <= q / 2)).
- {
- rewrite rmorphN /= {1 3}Zp_cast //.
- rewrite modnB; [|lia|lia].
- assert (0 < (a%:~R : 'Z_q) %% q).
- {
- rewrite modn_small //.
- }
- rewrite modnn H mul1n addn0 modn_small //.
- by apply liftc_neg_prop_alt.
- }
- rewrite rmorphN /=.
- case: leqP => /=; intros; rewrite {1 3}Zp_cast // modn_small //; lia.
-Qed.
-
-Lemma int_Zp_0 {q : nat} (a : int) :
- 1 < q ->
- (a%:~R : 'Z_q) = 0 ->
- intdiv.modz a q = 0.
-Proof.
- move => qbig.
- rewrite Zp_int //.
- move/(f_equal val).
- case: a => n ; rewrite /inZp /= Zp_cast //.
- - rewrite intdiv.modz_nat => ->; lia.
- - rewrite intdiv.modNz_nat; try lia.
- rewrite modnS.
- case: (boolP (q %| n.+1)%N) => pred a0.
- + suff: ((n %% q)%N = (q-1)%N) by lia.
- move /dvdnP in pred.
- case: pred => x xeq.
- have HH: (n + q = x * q + (q-1))%N by lia.
- have: ((n + q) %% q = (x * q + (q - 1)) %% q)%N.
- {
- by rewrite HH.
- }
- rewrite -modnDm -[RHS]modnDm modnn addn0 modnMl add0n modn_mod => ->.
- rewrite modn_mod modn_small; lia.
- + rewrite modn_small in a0; lia.
- Qed.
-
-Lemma int_Zp_0_alt {q : nat} (a : int) :
- 1 < q ->
- intdiv.modz a q = 0 ->
- (a%:~R : 'Z_q) = 0.
-Proof.
- move => qbig qmod.
- rewrite Zp_int //.
- destruct a.
- - rewrite intdiv.modz_nat in qmod.
- apply ord_inj.
- rewrite /= Zp_cast //.
- lia.
- - rewrite intdiv.modNz_nat in qmod; try lia.
- apply ord_inj.
- rewrite /= Zp_cast //.
- rewrite modnS.
- case: (boolP (q %| n.+1)%N) => pred.
- + by rewrite subn0 modnn.
- + assert ((n%% q)%N = q-1)%N by lia.
- rewrite H.
- replace ((q-1).+1) with q by lia.
- by rewrite subnn mod0n.
-Qed.
-
-Lemma int_Zp_0_iff (q : nat) (a : int) :
- 1 < q ->
- intdiv.modz a q = 0 <->
- (a%:~R : 'Z_q) = 0.
-Proof.
- move => qbig.
- split.
- - by apply int_Zp_0_alt.
- - by apply int_Zp_0.
-Qed.
-
-Lemma int_Zp_eq_iff (q : nat) (a1 a2 : int) :
- 1 < q ->
- intdiv.modz a1 q = intdiv.modz a2 q <->
- (a1%:~R : 'Z_q) = (a2%:~R : 'Z_q).
-Proof.
- split; intros.
- - assert (intdiv.modz (a1 - a2) q = 0).
- {
- by rewrite -intdiv.modzDm H0 intdiv.modzDm subrr intdiv.mod0z.
- }
- rewrite int_Zp_0_iff // in H1.
- rewrite rmorphD rmorphN /= in H1.
- apply (f_equal (fun z => z + a2%:~R)) in H1.
- rewrite add0r in H1.
- rewrite -H1.
- ring.
- - assert (((a1 - a2)%:~R : 'Z_q) = 0).
- {
- by rewrite rmorphD rmorphN /= H0 subrr.
- }
- rewrite -int_Zp_0_iff // in H1.
- generalize (intdiv.divz_eq (a1 - a2) q); intros.
- rewrite H1 addr0 in H2.
- apply (f_equal (fun z => z + a2)) in H2.
- replace (a1 - a2 + a2) with (a1) in H2 by lia.
- rewrite H2 -intdiv.modzDm intdiv.modzMl add0r.
- by rewrite intdiv.modz_mod.
-Qed.
-
-Lemma liftc_neg_alt_alt (q : nat) (a : int) :
- 1 < q ->
- intdiv.dvdz q (2 * a) \/
- zliftc ((-a)%:~R : 'Z_q) = - zliftc (a %:~R : 'Z_q).
-Proof.
- move=> qbig.
- case: (boolP (intdiv.dvdz q (2 * a))) => div; [by left|].
- right.
- apply liftc_neg_alt => //.
- move: div.
- apply contraNN => eqq.
- apply/intdiv.dvdzP.
- exists (intdiv.divz (2%Z * a) q).
- have: 2 * (a%:~R : 'Z_q) == 0.
- {
- move: eqq.
- move/eqP/(f_equal (fun x:nat => x + (a%:~R : 'Z_q)))%nat.
- have aqsmall: (a%:~R : 'Z_q) < q by apply Z_q_small.
- rewrite (_:(q - (a%:~R : 'Z_q) + (a%:~R : 'Z_q))%N = q); [| lia].
- move => eqq.
- suff: (a%:~R + a%:~R) == ( 0 : 'Z_q).
- {
- move/eqP => <-.
- apply/eqP.
- ring.
- }
- by rewrite /eq_op/= -eqq Zp_cast // modnn.
- }
- generalize (intdiv.divz_eq (2%Z*a) q); intros eqq2 a0.
- rewrite {1}eqq2.
- suff ->:(intdiv.modz (2%Z*a) q = 0).
- - by rewrite addr0.
- - rewrite -(intrM (Zp_ringType (Zp_trunc q)) (2%Z) a) in a0.
- move /eqP in a0.
- by apply int_Zp_0.
-Qed.
-
-Lemma liftc_reduce_prod2 (p q : nat) (a : int) :
- 1 < q ->
- 1 < p ->
- zliftc ((a %:~R : 'Z_(p*q))*+ p) = (zliftc (a%:~R : 'Z_q)) *+p.
-Proof.
- move=> qbig pbig.
- assert (pqbig : 1 < p*q) by lia.
- rewrite /zliftc.
- have ->: ((a %:~R : 'Z_(p*q))*+ p <= p * q / 2) = (((a%:~R : 'Z_q)) <= q /2).
- {
- rewrite -mulr_natl !div2_le.
- destruct a; rewrite /intmul !Zp_mulrn mul1n /inZp /= !Zp_cast //.
- - rewrite [modn (S 0) (muln p q)]modn_small //.
- rewrite [modn (S 0) q]modn_small // !mul1n.
- rewrite modn_mul2 modn_prod2r.
- rewrite mulnA (mulnC 2%N p) -mulnA.
- rewrite leq_pmul2l //.
- lia.
- - rewrite (modn_small pqbig) (modn_small qbig) !mul1n.
- assert (p < p*q).
- {
- replace (p) with (p*1)%N at 1 by lia.
- rewrite ltn_pmul2l //; lia.
- }
- rewrite (modn_small H) -muln_modr.
- rewrite mulnA (mulnC 2%N p) -mulnA.
- rewrite leq_pmul2l; [| lia].
- assert ((((p * q - n.+1 %% (p * q)) %% (p * q)) %% q) =
- ((q - n.+1 %% q) %% q))%N.
- {
- rewrite modn_muln_r //.
- rewrite modnB; [|lia|lia].
- rewrite modnMl addn0.
- rewrite modn_muln_r //.
- case: ltP; intros.
- - rewrite mul1n.
- assert (q - n.+1 %% q < q) by lia.
- by rewrite (modn_small H0).
- - rewrite mul0n sub0n.
- assert (n.+1 %% q = 0)%N by lia.
- by rewrite H0 subn0 modnn.
- }
- by rewrite H0.
- }
- case: leqP; rewrite reduce_prod2 //; lia.
-Qed.
-
-Lemma lift_reduce_prod2 (p q : nat) (a : {poly int}) :
- 1 < q ->
- 1 < p ->
- zlift (q_reduce (p * q) a *+ p) =
- zlift (q_reduce q a) *+p.
-Proof.
- intros.
- rewrite /zlift -polyP => i.
- rewrite coefMn !coef_map_id0 // coefMn.
- rewrite /q_reduce coef_map_id0 //.
- by apply liftc_reduce_prod2.
-Qed.
-
-Lemma zlift_valid {q : nat} (c : {poly 'Z_q}) :
- 1 < q ->
- q_reduce q (zlift c) = c.
-Proof.
- intros.
- rewrite /q_reduce /zlift -polyP => i.
- by rewrite !coef_map_id0 // zliftc_valid.
-Qed.
-
-Lemma linearize_prop_mult {q p : nat} (c2 : {poly 'Z_q})
- (s e : {poly int}) (a : {poly 'Z_(p*q)}) :
- let c2' := q_reduce (p*q) (zlift c2) in
- (map_poly (fun P => c2' * P) (ev_key s e a)).[q_reduce (p * q) s] =
- c2' * (q_reduce (p*q) (exp s 2) *+ p) + c2' * (q_reduce (p * q) e).
-Proof.
- rewrite /ev_key /key_switch_key.
- rewrite map_Poly_id0; [|by rewrite mulr0].
- rewrite horner_Poly /= mul0r add0r.
- rewrite !(zlift_add2_eq,mulrDr) rmorphMn /=.
- ring.
-Qed.
-
-Lemma linearize_prop_div {q p : nat} (c2 : {poly 'Z_q})
- (s e : {poly int}) (a : {poly 'Z_(p*q)}) :
- let c2' := q_reduce (p*q) (zlift c2) in
- q_reduce q (div_round (zlift ((map_poly (fun P => c2' * P) (ev_key s e a)).[q_reduce (p * q) s])) p) =
- q_reduce q (div_round (zlift (c2' * (q_reduce (p*q) (exp s 2) *+ p) + c2' * (q_reduce (p * q) e))) p).
-Proof.
- simpl.
- by rewrite linearize_prop_mult.
-Qed.
-
-Lemma q_reduce_0_div (q : nat) (qbig: 1 < q) (p : {poly int}) :
- q_reduce q p = 0 ->
- { e : {poly int} | p = e *+ q}.
-Proof.
- intros.
- exists (map_poly (fun c => intdiv.divz c q) p).
- apply polyP.
- intros ?.
- rewrite /q_reduce in H.
- rewrite coefMn coef_map_id0; [|by rewrite intdiv.div0z].
- rewrite [LHS](intdiv.divz_eq _ q).
- assert (intdiv.modz p`_x q = 0).
- {
- move: H.
- rewrite -(map_poly0 [eta intr]).
- move/polyP/(_ x).
- rewrite !coef_map_id0 //.
- rewrite coef0 /=.
- move/(f_equal val)=> /= HH.
- apply int_Zp_0 => //.
- by apply ord_inj.
- }
- lia.
-Qed.
-
-Lemma poly_Zq_muln_q {q : nat} (qbig:1 < q) (a : {poly 'Z_q}) :
- a *+ q = 0.
-Proof.
- rewrite -polyP => i.
- rewrite coefMn coef0 Zp_mulrn.
- apply ord_inj => /=.
- by rewrite {3}Zp_cast // modnMl.
-Qed.
-
-Lemma q_reduce_muln_q (q : nat) (qbig:1 < q) (a : {poly int}) :
- q_reduce q (a *+ q) = 0.
-Proof.
- by rewrite reduce_prod1 poly_Zq_muln_q.
-Qed.
-
-Lemma zlift_red (q : nat) (p : {poly int}) :
- 1 < q ->
- { e : {poly int} |
- zlift (q_reduce q p) = p + e *+q}.
-Proof.
- intros.
- assert (q_reduce q (zlift (q_reduce q p) - p) = 0).
- {
- rewrite rmorphD rmorphN /= zlift_valid //.
- ring.
- }
- apply q_reduce_0_div in H0 => //.
- destruct H0.
- exists x.
- rewrite -e.
- ring.
-Qed.
-Lemma linearize_prop_div2 {q p : nat} (qbig : 1 < q) (pbig : 1 < p) (c2 : {poly 'Z_q})
- (s e : {poly int}) (a : {poly 'Z_(p*q)}) :
- let c2' := q_reduce (p*q) (zlift c2) in
- { e2 : {poly int} |
- q_reduce q (div_round (zlift (c2' * (q_reduce (p*q) (exp s 2) *+ p) + c2' * (q_reduce (p * q) e))) p) =
- c2 * (q_reduce q (exp s 2)) + q_reduce q (div_round ((zlift c2) * e) p + e2)}.
-Proof.
- assert (pqbig: (1 < p * q)) by lia.
- assert (pno: (Posz p <> 0)) by lia.
- destruct (zlift_red (p * q) (zlift c2 * e) pqbig).
- exists
- (
- div_round_add2_perturb (zlift c2 * e) (x *+ (p * q)) p pno +
- div_round
- (zlift_add2_perturb pqbig (q_reduce (p * q) (zlift c2 * s ^+ 2) *+ p)
- (q_reduce (p * q) (zlift c2 * e))) p +
- div_round_add2_perturb
- (zlift (c2 * q_reduce q (s ^+ 2)) *+ p + zlift c2 * e + x *+ (p * q))
- (zlift_add2_perturb pqbig (q_reduce (p * q) (zlift c2 * s ^+ 2) *+ p)
- (q_reduce (p * q) (zlift c2 * e))) p pno).
- rewrite (zlift_add2_eq,mulrDr) //.
- rewrite div_round_add2_eq //.
- rewrite !rmorphD /= -rmorphMn -rmorphM /=.
- rewrite mulrnAr rmorphMn /=.
- rewrite lift_reduce_prod2 //.
- rewrite div_round_muln_add_l; [|lia].
- rewrite rmorphD /=.
- rewrite zlift_valid // rmorphM /= zlift_valid // -!addrA.
- f_equal.
- rewrite -rmorphM /=.
- rewrite e0.
- rewrite div_round_add2_eq // !rmorphD /= -!addrA.
- replace (q_reduce q (div_round (x *+ (p * q)) p)) with (0 : {poly 'Z_q}).
- - by rewrite add0r.
- - rewrite mulnC mulrnA.
- assert (p <> 0%N) by lia.
- by rewrite div_round_muln // q_reduce_muln_q.
-Qed.
-
-Lemma absz_triang_sum {n} (a : 'I_n -> int) :
- `|\sum_(j a (fintype.lift ord0 z))).
- lia.
-Qed.
-
-Lemma add_zero (a b : nat) :
- b = 0%N ->
- a + b <= a.
-Proof.
- lia.
-Qed.
-
-Lemma delta_maxnorm_prod (a b : {poly int}) :
- icoef_maxnorm (a * b) <= (size b) * icoef_maxnorm a * icoef_maxnorm b.
-Proof.
- rewrite /icoef_maxnorm.
- apply /bigmax_leqP.
- intros.
- rewrite coefMr.
- eapply leq_trans.
- apply absz_triang_sum.
- rewrite /= -mulnA.
- assert (\sum_(j < i.+1) `|(a`_(i - j) * b`_j)%R| <=
- \sum_(j < size b) `|(a`_(i - j) * b`_j)%R|)%N.
- {
- case: (boolP (i < size b)); intros.
- - replace (size b) with (i.+1 + (size b - i.+1))%nat by lia.
- rewrite big_split_ord /=.
- assert (0 <= \sum_(i0 < size b - i.+1) `|(a`_(i - (i.+1 + i0)) * b`_(i.+1 + i0))%R|).
- {
- apply big_ind; lia.
- }
- lia.
- - replace (i.+1) with (size b + (i.+1 - size b))%nat by lia.
- rewrite big_split_ord /=.
- apply add_zero.
- under eq_big_seq.
- + intros.
- rewrite abszM.
- rewrite [b`_ _]nth_default /=.
- * rewrite muln0.
- over.
- * simpl.
- lia.
- + by rewrite sum_nat_const muln0.
- }
- eapply leq_trans; [apply H0 |].
- move: (@big_sum_le_const _ (index_enum (ordinal_finType (size b)))
- (fun j => (absz
- (mul (R:=int_Ring) (nth (zero int_Ring) a (subn i j))
- (nth (zero int_Ring) b j))))
- ((\max_(j < size a) `|a`_j|) * (\max_(j < size b) `|b`_j|))).
- rewrite /index_enum /= -!enumT !size_enum_ord.
- apply => p _.
- rewrite abszM.
- apply leq_mul; simpl.
- clear H0.
- - case (ltnP (i - p) (size a)) => ineq.
- + eapply leq_trans; cycle 1.
- by apply leq_bigmax_seq; rewrite ?mem_enum.
- by rewrite (_:(nat_of_ord i - nat_of_ord p)%nat=(nat_of_ord ((Ordinal (ineq))))) ?leqnn.
- + assert (a`_(i - p) = 0).
- {
- by apply nth_default.
- }
- rewrite H0 /=.
- lia.
- - eapply leq_trans; cycle 1.
- by apply leq_bigmax_seq; rewrite ?mem_enum.
- easy.
-Qed.
-
-Lemma Rabs_mul (a b : R) :
- Rabs a * Rabs b = Rabs (a * b)%R.
-Proof.
- by rewrite /mul /= Rabs_mult.
-Qed.
-
-Lemma nearest_round_le (r1 r2 : R) :
- (r1 <= r2)%O ->
- (nearest_round r1 <= nearest_round r2)%O.
-Proof.
- intros.
- rewrite /nearest_round /ran_round.
- generalize (upi_le r1 r2 H); intros.
- case : (boolP (Order.lt _ _)); intros.
- - case : (boolP (Order.lt _ _)); intros; trivial.
- assert (upi r1 < upi r2)%O.
- {
- rewrite -(ltr_int R_numDomainType).
- rewrite -(ler_int R_numDomainType) in H0.
- lra.
- }
- lia.
- - case : (boolP (Order.lt _ _)); lia.
-Qed.
-
-Lemma Rmult_leb_compat_r (r r1 r2 : R) : ((0 : R) <= r)%O -> (r1 <= r2)%O -> (r1 * r <= r2 * r)%O.
-Proof.
- move/RlebP=> HH1.
- move/RlebP=> HH2.
- by apply/RlebP/Rmult_le_compat_r.
-Qed.
-
-Lemma nearest_round_int_le (r1 r2 d : int) :
- ((0 : int) <= d)%O ->
- (r1 <= r2)%O ->
- (nearest_round_int r1 d <= nearest_round_int r2 d)%O.
-Proof.
- intros.
- rewrite /nearest_round_int.
- apply nearest_round_le.
- apply Rmult_leb_compat_r.
- - rewrite invr_ge0.
- rewrite (_:((0 : R) = 0%:~R)); last by lra.
- by rewrite ler_int.
- - by rewrite ler_int.
-Qed.
-
-Lemma nearest_round_pos (r : R) :
- ((0 : R) <= r)%O ->
- ((0 : int) <= nearest_round r)%O.
-Proof.
- intros.
- rewrite -nearest_round_0.
- by apply nearest_round_le.
-Qed.
-
-Lemma nearest_round_sgn_pos (r : R) :
- ((0 : R) <= r)%O ->
- ((0 : int) <= nearest_round_sgn r)%O.
-Proof.
- intros.
- rewrite /nearest_round_sgn.
- case: (boolP (Order.lt _ _)); intros.
- - assert ((0 : R) < 0)%O by lra.
- by rewrite Order.POrderTheory.ltxx in H0.
- - apply nearest_round_pos.
- lra.
-Qed.
-
-Lemma nearest_round_neg (r : R) :
- (r <= 0 )%O ->
- (nearest_round r <= 0)%O.
-Proof.
- intros.
- rewrite -nearest_round_0.
- by apply nearest_round_le.
-Qed.
-
-Lemma nearest_round_sgn_neg (r : R) :
- (r <= 0)%O ->
- (nearest_round_sgn r <= 0)%O.
-Proof.
- intros.
- rewrite /nearest_round_sgn.
- case: (boolP (Order.lt _ _)); intros.
- - assert (-r >= (0 : R))%O by lra.
- apply nearest_round_pos in H0.
- rewrite Ropp_opp.
- lra.
- - assert (r = 0) by lra.
- rewrite H0 nearest_round_0.
- lra.
-Qed.
-
-Lemma nearest_round_int_pos (p : nat) (c : int) :
- p != 0%N ->
- ((0 : int) <= c)%O ->
- ((0 : int) <= nearest_round_int c p)%O.
-Proof.
- intros.
- rewrite -(nearest_round_int0 p).
- apply nearest_round_le.
- rewrite mul0r.
- by apply div_int_pos.
-Qed.
-
-Lemma nearest_round_int_neg (p : nat) (c : int) :
- p != 0%N ->
- ((0 : int) >= c)%O ->
- ((0 : int) >= nearest_round_int c p)%O.
-Proof.
- intros.
- rewrite -(nearest_round_int0 p).
- apply nearest_round_le.
- rewrite mul0r.
- generalize div_int_neg; intros.
- by apply div_int_neg'.
-Qed.
-
-Lemma odd_mul_half (p : nat) :
- odd p ->
- ~(exists (c : int),
- c %:~R = (1/2 : R) *+ p).
-Proof.
- intros ??.
- destruct H0.
- apply (f_equal (fun z => 2 * z)) in H0.
- assert ((2 * (x %:~R) :R) <> p %:~R).
- {
- replace (2 * x %:~R) with ((2 * x)%:~R : R) by lra.
- intros ?.
- move /eqP in H1.
- rewrite eqr_int in H1.
- destruct x.
- - assert (2 * n = p)%N by lia.
- rewrite -H2 in H.
- replace (2 * n)%N with (n.*2) in H by lia.
- by rewrite odd_double in H.
- - lia.
- }
- rewrite H0 mulrnAr mulrA mulr1 in H1.
- by rewrite divff in H1; [|lra].
-Qed.
-
-Lemma odd_div_upi_not_half (p : nat) (c : int) :
- odd p ->
- (((upi (c%:~R / p%:~R)%R)%:~R : R) - c%:~R / p%:~R != (1 / 2 : R))%O.
-Proof.
- intros.
- apply /eqP.
- intros ?.
- apply (f_equal (fun z => z *+ p)) in H0.
- assert (exists (b : int),
- b %:~R = (1/2 : R) *+ p).
- {
- exists ((upi (c%:~R / p%:~R)%R)*+p - c).
- rewrite -H0.
- field.
- replace (zero (Ring.zmodType (ssrnum.Num.NumField.ringType R_numFieldType)))
- with (((0%:R):R)) by lra.
- rewrite eqr_nat.
- lia.
- }
- by generalize (odd_mul_half p H); intros.
-Qed.
-
-Lemma odd_div_upi_compare (p : nat) (c : int) :
- odd p ->
- (((upi (c%:~R / p%:~R)%R)%:~R : R) - c%:~R / p%:~R < (1 / 2 : R))%O ||
- (((upi (c%:~R / p%:~R)%R)%:~R : R) - c%:~R / p%:~R > (1 / 2 : R))%O.
-Proof.
- intros.
- generalize (odd_div_upi_not_half p c H); intros.
- apply (Order.TotalTheory.lt_total H0).
-Qed.
-
-Lemma nearest_round_int_negate (p : nat) (c : int) :
- odd p ->
- nearest_round_int (-c) p = - nearest_round_int c p.
-Proof.
- intros.
- rewrite /nearest_round_int.
- assert ((((- c)%:~R : R)/ p%:~R)%R = - (c %:~R / p%:~R)%R).
- {
- lra.
- }
- rewrite H0.
- apply nearest_round_opp_not_half.
- apply /eqP.
- by apply odd_div_upi_not_half.
-Qed.
-
-Lemma opp_Ropp (r : R) :
- opp r = Ropp r.
-Proof.
- by rewrite /opp /=.
-Qed.
-
-Lemma nearest_round_sgn_negate (r : R) :
- nearest_round_sgn (-r) = - nearest_round_sgn r.
-Proof.
- rewrite /nearest_round_sgn.
- rewrite -!opp_Ropp.
- case: (boolP (Order.lt r 0)); intros.
- - case: (boolP (Order.lt _ _)); intros; try lra.
- rewrite ltrNl oppr0 in p0.
- generalize (Order.POrderTheory.lt_trans p0 p); intros.
- by rewrite Order.POrderTheory.ltxx in H.
- - case: (boolP (Order.lt _ _)); intros.
- + by rewrite opprK.
- + assert (r = 0%R) by lra.
- by rewrite H oppr0 nearest_round_0 oppr0.
-Qed.
-
-Lemma nearest_round_int_sgn_negate (p : nat) (c : int) :
- nearest_round_int_sgn (-c) p = - nearest_round_int_sgn c p.
-Proof.
- rewrite /nearest_round_int_sgn.
- replace ((- c)%:~R / p%:~R)%R with
- (- (c%:~R / p%:~R)%R :R) by lra.
- by rewrite nearest_round_sgn_negate.
-Qed.
-
-Lemma nearest_round_int_sgn_nat (n d : nat) :
- d != 0%N ->
- nearest_round_int n d = nearest_round_int_sgn n d.
-Proof.
- rewrite /nearest_round_int /nearest_round_int_sgn /nearest_round_sgn.
- case: (boolP (Order.lt _ _)); trivial.
- intros.
- generalize (div_nat_pos n d H); intros.
- assert ((0 : R) < (0 : R))%O.
- {
- eapply Order.POrderTheory.le_lt_trans.
- apply H0.
- apply p.
- }
- by rewrite Order.POrderTheory.ltxx in H1.
-Qed.
-
-Lemma nearest_round_int_sgn_odd (d : nat) (n : int) :
- odd d ->
- nearest_round_int n d = nearest_round_int_sgn n d.
-Proof.
- intros.
- destruct n.
- - rewrite nearest_round_int_sgn_nat; trivial.
- apply odd_gt0 in H.
- lia.
- - rewrite NegzE.
- rewrite nearest_round_int_negate //.
- rewrite nearest_round_int_sgn_negate.
- rewrite nearest_round_int_sgn_nat; trivial.
- lia.
-Qed.
-
-Lemma nearest_round_sgn_le' (r1 r2 : R) :
- (Rabs r1 <= Rabs r2)%O ->
- (Posz `|nearest_round_sgn r1| <= Posz `|nearest_round_sgn r2|)%O.
-Proof.
- intros.
- rewrite /nearest_round_sgn.
- case: (boolP (Order.lt _ _)); intros.
- - rewrite Rabs_left in H; [|apply /RltP; rewrite /zero /= in p; lra].
- case: (boolP (Order.lt _ _)); intros.
- + rewrite Rabs_left in H; [|apply /RltP; rewrite /zero /= in p0; lra].
- rewrite !lez0_abs.
- * rewrite !opprK.
- by apply nearest_round_le.
- * assert ((0 : R) <= (- r2)%R)%O by lra.
- generalize (nearest_round_pos (- r2) H0); lra.
- * assert ((0 : R) <= (- r1)%R)%O by lra.
- generalize (nearest_round_pos (- r1) H0); lra.
- + rewrite Rabs_right in H; cycle 1.
- * rewrite /zero /= in i.
- apply Rle_ge.
- apply /RleP.
- lra.
- * rewrite abszN !gez0_abs.
- -- by apply nearest_round_le.
- -- apply nearest_round_pos; lra.
- -- apply nearest_round_pos.
- rewrite Ropp_opp; lra.
- - rewrite Rabs_right in H; cycle 1.
- + move /RltP in i.
- rewrite /zero /= in i.
- coq_lra.
- + case: (boolP (Order.lt _ _)); intros.
- * rewrite Rabs_left in H; [|apply /RltP; rewrite /zero /= in p; lra].
- rewrite abszN !gez0_abs.
- -- by apply nearest_round_le.
- -- apply nearest_round_pos; lra.
- -- apply nearest_round_pos; lra.
- * rewrite Rabs_right in H; cycle 1.
- -- rewrite /zero /= in i0.
- apply Rle_ge.
- apply /RleP.
- lra.
- -- rewrite !gez0_abs.
- ++ by apply nearest_round_le.
- ++ apply nearest_round_pos; lra.
- ++ apply nearest_round_pos; lra.
-Qed.
-
-Lemma nearest_round_sgn_le (r1 r2 : R) :
- (Rabs r1 <= Rabs r2)%O ->
- `|nearest_round_sgn r1| <= `|nearest_round_sgn r2|.
-Proof.
- intros.
- apply nearest_round_sgn_le' in H.
- lia.
-Qed.
-
-Lemma Rzero_0 :
- (0 : R) = IZR Z0.
-Proof.
- by rewrite /zero /=.
-Qed.
-
-Lemma nearest_round_int_sgn_le (n1 n2 d : int) :
- ((0 : int) < d)%O ->
- `|n1| <= `|n2| ->
- `|nearest_round_int_sgn n1 d| <= `|nearest_round_int_sgn n2 d|.
-Proof.
- intros.
- assert ((0 : R) < d%:~R)%O.
- {
- replace (0 : R) with (((0 : int)%:~R):R) by ring.
- by rewrite ltr_int.
- }
- assert (d%:~R != (0 : R))%O.
- {
- assert (d != 0) by lia.
- by rewrite -(eqr_int R_numDomainType) in H2.
- }
- rewrite /nearest_round_int_sgn.
- apply nearest_round_sgn_le.
- case: (boolP ((0 : int) <= n1)%O); intros.
- - rewrite Rabs_right; cycle 1.
- + apply Rle_ge.
- apply /RleP.
- rewrite -(ler_pM2r H1) mul0r -mulrA (mulrC _ d%:~R) divff; [|apply H2].
- rewrite mulr1.
- rewrite -(ler_int R_numDomainType) in p.
- lra.
- + case: (boolP ((0 : int) <= n2)%O); intros.
- * rewrite Rabs_right; cycle 1.
- -- apply Rle_ge.
- apply /RleP.
- rewrite -(ler_pM2r H1) mul0r -mulrA (mulrC _ d%:~R) divff; [|apply H2].
- rewrite mulr1.
- rewrite -(ler_int R_numDomainType) in p0.
- lra.
- -- rewrite -(ler_pM2r H1) -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2].
- rewrite !mulr1 ler_int.
- lia.
- * rewrite Rabs_left; cycle 1.
- -- apply /RltP.
- rewrite -(ltr_pM2r H1) -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2].
- rewrite mulr1 mulr0.
- rewrite -Order.TotalTheory.ltNge -(ltr_int R_numDomainType) in i.
- lra.
- -- rewrite -(ler_pM2r H1) Ropp_opp mulNr -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2].
- rewrite !mulr1 -mulrNz ler_int.
- lia.
- - rewrite Rabs_left; cycle 1.
- + apply /RltP.
- rewrite -(ltr_pM2r H1) -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2].
- rewrite mulr1 mulr0.
- rewrite -Order.TotalTheory.ltNge -(ltr_int R_numDomainType) in i.
- lra.
- + case: (boolP ((0 : int) <= n2)%O); intros.
- * rewrite Rabs_right; cycle 1.
- -- apply Rle_ge.
- apply /RleP.
- rewrite -(ler_pM2r H1) mul0r -mulrA (mulrC _ d%:~R) divff; [|apply H2].
- rewrite mulr1.
- rewrite -(ler_int R_numDomainType) in p.
- lra.
- -- rewrite -(ler_pM2r H1) Ropp_opp mulNr -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2].
- rewrite !mulr1 -mulrNz ler_int.
- lia.
- * rewrite Rabs_left; cycle 1.
- -- apply /RltP.
- rewrite -(ltr_pM2r H1) -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2].
- rewrite mulr1 mulr0.
- rewrite -Order.TotalTheory.ltNge -(ltr_int R_numDomainType) in i0.
- lra.
- -- rewrite -(ler_pM2r H1) Ropp_opp !mulNr -!mulrA !(mulrC _ d%:~R) !divff; [|apply H2].
- rewrite !mulr1 -!mulrNz ler_int.
- lia.
-Qed.
-
-Lemma nearest_round_leq (r : R) :
- (Rabs(r - (nearest_round r)%:~R) <= 1 / 2)%O.
-Proof.
- rewrite /nearest_round /ran_round.
- case_eq (Order.lt ((upi r)%:~R -r) (1 / 2)); intros.
- - rewrite Rabs_minus_sym.
- rewrite Rabs_right.
- + rewrite /Rminus Rplus_add Ropp_opp.
- lra.
- + destruct (upi_bound r).
- rewrite /Rminus Rplus_add Ropp_opp.
- apply Rle_ge.
- apply /RleP.
- move /RltP in H0.
- replace (IZR Z0) with (zero R_zmodType).
- * lra.
- * by rewrite /zero /=.
- - rewrite /Rminus Rplus_add Ropp_opp.
- rewrite Rabs_right.
- + lra.
- + destruct (upi_bound r).
- apply Rle_ge.
- apply /RleP.
- move /RleP in H1.
- replace (IZR Z0) with (zero R_zmodType).
- * lra.
- * by rewrite /zero /=.
-Qed.
-
-Lemma nearest_round_sgn_leq (r : R) :
- (Rabs(r - (nearest_round_sgn r)%:~R) <= 1 / 2)%O.
-Proof.
- rewrite /nearest_round_sgn.
- case: (boolP (Order.lt _ _)); intros.
- - generalize (nearest_round_leq (- r)); intros.
- rewrite -Rabs_Ropp in H.
- rewrite /Rminus Rplus_add !Ropp_opp in H.
- rewrite opprD !opprK in H.
- rewrite /Rminus Rplus_add !Ropp_opp.
- eapply Order.POrderTheory.le_trans; cycle 1.
- apply H.
- apply /RlebP.
- right.
- f_equal.
- ring.
- - apply nearest_round_leq.
-Qed.
-
-Lemma nearest_round_int_leq (p : nat) (a : int) :
- p != 0%N ->
- `|a - (nearest_round_int a p)*+p| <= p./2.
-Proof.
- intros.
- rewrite /nearest_round_int.
- generalize (nearest_round_leq (a%:~R / p%:~R)%R); intros.
- assert (Rabs ((a - nearest_round (a%:~R / p%:~R)%R *+ p)%:~R : R)<= p%:R / 2)%O.
- {
- rewrite -mulr_natl.
- assert ((0 : R) < p%:R)%O.
- {
- rewrite pmulrn_rgt0; [lia|lra].
- }
- rewrite -(ler_pM2l H1) in H0.
- assert (Rabs (p%:R) = p%:R).
- {
- rewrite Rabs_right //.
- apply Rle_ge.
- apply /RleP.
- replace (IZR Z0) with (zero R_zmodType); trivial.
- lra.
- }
- rewrite -{1}H2 Rabs_mul mulrDr in H0.
- assert (Rabs (a - nearest_round (a%:~R / p%:~R)%R *+ p)%:~R =
- Rabs (p%:R * (a%:~R / p%:~R) + p%:R * (- (nearest_round (a%:~R / p%:~R)%R)%:~R)%R)%R).
- {
- f_equal.
- field.
- lra.
- }
- rewrite -H3 in H0.
- rewrite -mulrA.
- apply H0.
- }
- assert ((Rabs (((a - nearest_round (a%:~R / p%:~R)%R *+ p)%:~R ) : R)) *+2 <= (p%:R/2) *+2)%O.
- {
- rewrite lerMn2r.
- apply /orP.
- by right.
- }
- replace (p%:R / 2 *+ 2) with ((p%:R):R) in H2 by (by field).
- case: (boolP ((0 : R) <= (a - nearest_round (a%:~R / p%:~R)%R *+ p)%:~R *+ 2)%O); intros.
- - rewrite Rabs_right in H2.
- + rewrite -rmorphMn /= in H2.
- replace (p%:R) with ((p%:~R):R) in H2 by ring.
- rewrite ler_int in H2.
- rewrite -rmorphMn ler0z in p0.
- assert ((0 : int) <= (a - nearest_round (a%:~R / p%:~R)%R *+ p))%O by lia.
- assert ((a - nearest_round (a%:~R / p%:~R)%R *+ p) <= p./2)%O by lia.
- by rewrite -(gez0_abs H3) in H4.
- + apply Rle_ge.
- apply /RleP.
- replace (IZR Z0) with (zero R_zmodType); trivial.
- lra.
- - rewrite Rabs_left in H2.
- + rewrite Ropp_opp -rmorphN -rmorphMn /= in H2.
- replace (p%:R) with ((p%:~R):R) in H2 by ring.
- rewrite ler_int in H2.
- rewrite -Order.TotalTheory.ltNge in i.
- assert (((a - nearest_round (a%:~R / p%:~R)%R *+ p)%:~R : R) < 0%:~R )%O by lra.
- rewrite ltr_int in H3.
- assert (-(a - nearest_round (a%:~R / p%:~R)%R *+ p) <= p./2)%O by lia.
- by rewrite -(ltz0_abs H3) in H4.
- + apply /RltP.
- replace (IZR Z0) with (zero R_zmodType); trivial.
- lra.
- Qed.
-
-Lemma div_round_leq (p : nat) (a : {poly int}) :
- p != 0%N ->
- icoef_maxnorm (a - (div_round a p)*+p) <= p./2.
-Proof.
- rewrite /icoef_maxnorm /div_round=> pn0.
- apply/bigmax_leqP=> i _.
- by rewrite coefB coefMn coef_map_id0 ?nearest_round_int0 // nearest_round_int_leq.
-Qed.
-
-
-Lemma div_round_eq (p : nat) (a : {poly int}) :
- p != 0%N ->
- { c : {poly int} |
- a = (div_round a p)*+p + c /\
- icoef_maxnorm c <= p./2}.
-Proof.
- intros.
- exists (a - (div_round a p)*+p).
- split.
- - ring.
- - by apply div_round_leq.
-Qed.
-
-
-Lemma icoef_maxnorm_triang (a b : {poly int}) :
- icoef_maxnorm (a + b) <= icoef_maxnorm a + icoef_maxnorm b.
-Proof.
- rewrite /icoef_maxnorm.
- apply/bigmax_leqP => i _.
- eapply leq_trans; first by rewrite coefD; apply absz_triang.
- apply leq_add.
- - case: (boolP (i < size a)); intros.
- + eapply leq_trans; cycle 1.
- apply leq_bigmax_seq; rewrite ?mem_index_enum //.
- by apply (@leq_trans `|a`_(Ordinal p)|).
- + rewrite nth_default //.
- by rewrite leqNgt i0.
- - case: (boolP (i < size b)); intros.
- + eapply leq_trans; cycle 1.
- apply leq_bigmax_seq; rewrite ?mem_index_enum //.
- by apply (@leq_trans `|b`_(Ordinal p)|).
- + rewrite nth_default //.
- by rewrite leqNgt i0.
-Qed.
-
-Lemma icoef_maxnorm_neg (a : {poly int}) :
- icoef_maxnorm (-a) = icoef_maxnorm a.
-Proof.
- rewrite /icoef_maxnorm size_opp.
- apply eq_big; trivial.
- intros.
- by rewrite coefE abszN.
-Qed.
-
-Lemma nearest_round_int_mulp (a p : int) :
- p != 0 ->
- a = nearest_round_int (p * a) p.
-Proof.
- intros.
- rewrite /nearest_round_int.
- replace ((p * a)%:~R / p%:~R)%R with (a %:~R : R).
- - by rewrite nearest_round_int_val.
- - field.
- move: H.
- apply contra_neq.
- apply (@intr_inj R_numDomainType p 0).
-Qed.
-
-Lemma nearest_round_int_le_const (p : nat) (a c : int) :
- Posz p != 0 ->
- (a <= c * p)%O ->
- (nearest_round_int a p <= c)%O.
-Proof.
- intros.
- rewrite (nearest_round_int_mulp c p H).
- apply nearest_round_int_le; trivial.
- by rewrite mulrC.
-Qed.
-
-Lemma nearest_round_int_le_const_nat (p a c : nat) :
- Posz p != 0 ->
- (a <= c * p) ->
- (nearest_round_int a p <= c)%O.
-Proof.
- intros.
- by apply nearest_round_int_le_const.
-Qed.
-
-Lemma nearest_round_int_odd_abs (p : nat) (a : int) :
- odd p ->
- `|nearest_round_int a p| = `|nearest_round_int `|a| p|.
-Proof.
- intros.
- destruct a.
- - by rewrite absz_nat.
- - rewrite NegzE nearest_round_int_negate //.
- by rewrite !abszN absz_nat.
-Qed.
-
-Lemma nearest_round_int_leq2 (c p : nat) (a : int) :
- odd p ->
- `|a| <= c * p ->
- `|nearest_round_int a p| <= c.
-Proof.
- intros.
- assert (Posz p != 0) by lia.
- generalize (nearest_round_int_le_const_nat _ _ _ H1 H0); intros.
- rewrite nearest_round_int_odd_abs //.
- assert (p != 0%N) by lia.
- assert ((0 : int) <= `|a|)%O by lia.
- generalize (nearest_round_int_pos p `|a| H3 H4); intros.
- rewrite {1}/absz.
- case E: (nearest_round_int `|a| p); lia.
-Qed.
-
-Lemma icoef_maxnorm_div_round_leq (c p : nat) (a : {poly int}) :
- odd p ->
- icoef_maxnorm a <= c * p ->
- icoef_maxnorm (div_round a p) <= c.
-Proof.
- rewrite /icoef_maxnorm /div_round.
- intros oodp pmax.
- apply /bigmax_leqP => i _.
- rewrite coef_map_id0.
- - apply nearest_round_int_leq2; trivial.
- move /bigmax_leqP in pmax.
- have pfle: (size (map_poly (aR:=int_Ring) (rR:=int_Ring) (nearest_round_int^~ p) a)) <= size a
- by rewrite size_poly.
- apply (pmax (widen_ord pfle i) isT).
- - by rewrite nearest_round_int0.
-Qed.
-
-Lemma div_round_maxnorm_le (p : nat) (a : {poly int}):
- odd p ->
- icoef_maxnorm (div_round a p) <= `|nearest_round_int (icoef_maxnorm a) p|.
-Proof.
- intros.
- rewrite /icoef_maxnorm /div_round.
- apply /bigmax_leqP; intros.
- rewrite coef_poly.
- case: (boolP (i < size a)); intros; try lia.
- rewrite nearest_round_int_odd_abs //.
- assert (nearest_round_int `|a`_i| p <= nearest_round_int (\max_(j < size a) `|a`_j|) p)%O.
- {
- apply nearest_round_int_le.
- - lia.
- - have pfle: ((size (map_poly (aR:=int_Ring) (rR:=int_Ring) (nearest_round_int^~ p) a)) <= size a).
- by rewrite size_poly.
- by apply (@leq_bigmax_seq _ (index_enum (ordinal_finType (size a))) xpredT (fun j => `|a`_j|)%O
- (widen_ord pfle i)).
- }
- assert (p != 0%N) by lia.
- assert ((0 : int) <= nearest_round_int `|a`_i| p)%O.
- {
- by apply nearest_round_int_pos.
- }
- assert ((0 : int) <= nearest_round_int (\max_(j < size a) `|a`_j|) p)%O.
- {
- by apply nearest_round_int_pos.
- }
- lia.
-Qed.
-
-Lemma sum_bound (n B : nat) (F : nat -> nat) :
- (forall j, F j <= B) ->
- \sum_(j ineq.
- + by rewrite nth_default.
- + by apply: (@leq_bigmax_seq _ (index_enum (ordinal_finType (size a))) xpredT (fun j => `|a`_j|)%O (Ordinal ineq)).
- - case/orP: (leqVgt (size b) (i - i0)) => ineq.
- + by rewrite nth_default.
- + by apply: (@leq_bigmax_seq _ (index_enum (ordinal_finType (size b))) xpredT (fun j => `|b`_j|)%O (Ordinal ineq)).
- }
- rewrite sum_nat_const card_ord mulnA in H1.
- eapply leq_trans; try apply H1.
-
- case/orP: (leqVgt i.+1 (size a)) => ineq.
- - replace (size a) with (i.+1 + (size a - i.+1))%nat by lia.
- rewrite big_split_ord /=.
- lia.
- - rewrite (_:i.+1 = (size a + (i.+1 - size a))%N); last by lia.
- rewrite big_split_ord /=.
- suff ->: (\sum_(i0 < i.+1 - size a) `|(a`_(size a + i0) * b`_(i - (size a + i0)))%R|)%nat = 0%nat
- by lia.
- apply/eqP.
- rewrite sum_nat_eq0.
- apply/forall_inP => j _.
- rewrite nth_default /=; lia.
-Qed.
-
-Lemma mul_half2 (a b : nat) :
- a <= b ./2 <->
- 2 * a <= b .
-Proof.
- lia.
-Qed.
-
-Lemma mul_half3 (a b c : nat) :
- a <= b * c./2 ->
- 2 * a <= b * c.
-Proof.
- lia.
-Qed.
-
-Lemma mul_half4 (a p : nat) :
- odd p ->
- a * p./2 <= p * (a + 1)./2.
-Proof.
- intros.
- generalize (odd_halfK H); intros.
- apply (f_equal (fun z => z.+1)) in H0.
- replace (p.-1.+1) with p in H0 by lia.
- rewrite -{2}H0.
- case: (boolP (odd a)); intros.
- - generalize (odd_halfK p0); intros.
- apply (f_equal (fun z => z.+1)) in H1.
- replace (a.-1.+1) with a in H1 by lia.
- rewrite -{1}H1.
- apply (f_equal (fun z => (z + 1)./2)) in H1.
- rewrite -H1.
- replace (a./2.*2.+1 + 1)./2 with (a./2 + 1)%N by lia.
- lia.
- - apply even_halfK in i.
- rewrite -i.
- replace ((a./2.*2 + 1)./2) with (a./2) by lia.
- lia.
-Qed.
-
-Lemma div_round_mul_ex (p : nat) (a b : {poly int}) :
- odd p ->
- { c : {poly int} |
- (div_round a p) * b =
- (div_round (a * b) p) + c /\
- icoef_maxnorm c <= ((size b) * icoef_maxnorm b + 2)./2}.
-Proof.
- intros oddp.
- have pno: (p != 0%N) by lia.
- destruct (div_round_eq p a pno) as [?[??]].
- destruct (div_round_eq p (a * b) pno) as [?[??]].
- move /eqP in pno.
- apply (f_equal (fun z => z * b)) in H.
- rewrite mulrDl in H.
- rewrite {1}H mulrnAl in H1.
- apply (f_equal (fun z => z - x * b)) in H1.
- rewrite -!addrA subrr addr0 in H1.
- apply (f_equal (fun z => div_round z p)) in H1.
- rewrite div_round_muln // addrC in H1.
- rewrite div_round_muln_add // addrC in H1.
- exists (div_round (x0 - x * b) p).
- split; trivial.
- assert (pno': Posz p != 0) by lia.
- apply icoef_maxnorm_div_round_leq; trivial.
- rewrite mulnC.
- eapply leq_trans.
- apply icoef_maxnorm_triang.
- rewrite icoef_maxnorm_neg.
- assert (icoef_maxnorm (x * b) <= (size b * icoef_maxnorm b) * p./2).
- {
- generalize (icoef_maxnorm_mul b x); intros.
- rewrite mulrC in H3.
- eapply leq_trans.
- apply H3.
- apply leq_mul; lia.
- }
- rewrite addnC.
- eapply leq_trans.
- - apply leq_add.
- apply H3.
- apply H2.
- - assert (size b * icoef_maxnorm b * p./2 + p./2 <= (size b * icoef_maxnorm b + 1) * p./2) by lia.
- eapply leq_trans.
- apply H4.
- eapply leq_trans.
- apply mul_half4; trivial.
- apply leq_mul; lia.
-Qed.
-
-Lemma nearest_round_int_modz (n1 n2: int) (p1 p2 : nat) :
- 0 < p1 ->
- 1 < p2 ->
- (n1%:~R : 'Z_(p1 * p2)) = (n2%:~R : 'Z_(p1 * p2)) ->
- ((nearest_round_int n1 p1) %:~R : 'Z_p2) = ((nearest_round_int n2 p1) %:~R : 'Z_p2).
-Proof.
- intros.
- assert (Posz p1 <> 0) by lia.
- rewrite -int_Zp_eq_iff //.
- apply nearest_round_int_mod; trivial.
- rewrite int_Zp_eq_iff //.
- lia.
-Qed.
-
-Lemma div_round_mod (n1 n2 : {poly int}) (p1 p2 : nat) :
- 0 < p1 ->
- 1 < p2 ->
- q_reduce (p1 * p2) n1 = q_reduce (p1 * p2) n2 ->
- q_reduce p2 (div_round n1 p1) = q_reduce p2 (div_round n2 p1).
-Proof.
- rewrite /q_reduce /div_round !map_polyE.
- intros.
- rewrite -polyP => i.
- rewrite !coef_Poly !(nth_map_default 0 0); try by rewrite rmorph0.
- rewrite -polyP in H1.
- specialize (H1 i).
- rewrite !coef_Poly !(nth_map_default 0 0) in H1; try by rewrite rmorph0.
- rewrite !coef_Poly !(nth_map_default 0 0); try by rewrite rmorph0.
- - apply nearest_round_int_modz; trivial.
- - by rewrite nearest_round_int0.
- - by rewrite nearest_round_int0.
-Qed.
-
-Lemma lineariz_prop_div3 {q p : nat} (qbig : 1 < q) (pbig : 1 < p) (oddp : odd p) (c2 : {poly 'Z_q})
- (s e : {poly int}) (a : {poly 'Z_(p*q)}) :
- let c2' := q_reduce (p*q) (zlift c2) in
- { e2 : {poly 'Z_q} |
- q_reduce q (div_round (zlift ((map_poly (fun P => c2' * P) (ev_key s e a)).[q_reduce (p * q) s])) p) =
- (map_poly (fun P => q_reduce q (div_round ((zlift c2) * (zlift P)) (p%:Z)))
- (ev_key s e a)).[q_reduce q s] + e2}.
-Proof.
- assert (pqbig: (1 < p * q)) by lia.
- assert (pno: (Posz p <> 0)) by lia.
- destruct (div_round_mul_ex p (zlift c2 * zlift a) s oddp) as [?[??]].
- generalize (div_round_add2_perturb_small (zlift c2 * zlift a * s)
- (zlift c2 * zlift (- a * q_reduce (p * q) s + q_reduce (p * q) e + q_reduce (p * q) (s ^+ 2 *+ p))) p pno); intros peterb_small.
- exists
- (q_reduce q ( (div_round_add2_perturb (zlift c2 * zlift a * s)
- (zlift c2 * zlift (- a * q_reduce (p * q) s + q_reduce (p * q) e + q_reduce (p * q) (s ^+ 2 *+ p))) p pno -
- x))).
- rewrite !map_Poly_id0.
- + rewrite !horner_Poly /= !mul0r !add0r.
- rewrite -!rmorphM -rmorphD /=.
- rewrite H -(addrA _ x) (addrC x _) addrA.
- rewrite div_round_add2_eq_alt.
- assert (q_reduce q
- (div_round
- (zlift
- (q_reduce (p * q) (zlift c2) * a * q_reduce (p * q) s +
- q_reduce (p * q) (zlift c2) *
- (- a * q_reduce (p * q) s + q_reduce (p * q) e + q_reduce (p * q) (s ^+ 2 *+ p)))) p) =
- q_reduce q
- (div_round
- (zlift c2 * zlift a * s +
- zlift c2 * zlift (- a * q_reduce (p * q) s + q_reduce (p * q) e + q_reduce (p * q) (s ^+ 2 *+ p))) p)).
- {
- apply div_round_mod; try lia.
- rewrite rmorphD /= zlift_valid; try lia.
- f_equal.
- - by rewrite !rmorphM /= !zlift_valid; try lia.
- - by rewrite rmorphM /= zlift_valid; try lia.
- }
- rewrite H1 !rmorphD !rmorphN /= -[LHS]addr0 -!addrA.
- f_equal.
- ring.
- + by rewrite (zlift0 0) // mulr0 div_round0 rmorph0.
- + by rewrite mulr0.
-Qed.
-
-Lemma linearize_prop {q p : nat} (qbig : 1 < q) (pbig : 1 < p) (c2 : {poly 'Z_q}) (s e : {poly int}) (a : {poly 'Z_(p*q)}) :
- { e2 : {poly int} |
- (map_poly (fun P => q_reduce q (div_round ((zlift c2) * (zlift P)) (p%:Z)))
- (ev_key s e a)).[q_reduce q s] =
- c2 * (q_reduce q (exp s 2)) + q_reduce q (div_round ((zlift c2) * e) p + e2) /\
- icoef_maxnorm e2 <= 1}.
-Proof.
- assert (pqbig: (1 < p * q)) by lia.
- assert (pno: (Posz p <> 0)) by lia.
- eexists; split.
- - rewrite /ev_key /key_switch_key.
- rewrite map_Poly_id0.
- + rewrite horner_Poly /= mul0r add0r.
- rewrite !(zlift_add2_eq,mulrDr) rmorphMn /=.
- rewrite div_round_add2_eq.
- rewrite lift_reduce_prod2 // mulrnAr /=.
- rewrite div_round_muln_add; try lia.
- rewrite !rmorphD /=.
- rewrite rmorphM /=.
- rewrite !zlift_valid //.
- rewrite !div_round_add2_eq !rmorphD /=.
- admit.
- + by rewrite [zlift 0]zlift0 // mulr0 div_round0 rmorph0.
- - admit.
-Admitted.
-
-Definition rescale {q1 q2 : nat} (p : {poly 'Z_(q1 * q2)}) : {poly 'Z_q2} :=
- q_reduce q2 (div_round (zlift p) q1%:Z).
-
-Definition FHE_mult {q p : nat} (P Q : {poly {poly 'Z_q}})
- (evkey : {poly {poly 'Z_(p*q)}}) :=
- let PQ := FHE_mult_base P Q in
- linearize (PQ`_0) (PQ`_1) (PQ`_2) evkey.
-
-Lemma decrypt_mult {p q : nat} (P Q : {poly 'Z_q}) (PP QQ : {poly {poly 'Z_q}})
- (s e : {poly int}) (a : {poly 'Z_(p*q)}) :
- FHE_decrypt s PP = P ->
- FHE_decrypt s QQ = Q ->
- size PP = 2%N ->
- size QQ = 2%N ->
- {R : {poly int} |
- FHE_decrypt s (FHE_mult PP QQ (ev_key s e a)) =
- P * Q + q_reduce q (div_round (zlift (PP * QQ)`_2 * e) p + R) /\
- icoef_maxnorm R <= 1 }.
-Proof.
-(*
- intros.
- rewrite -(decrypt_mult_base P Q PP QQ s) //.
- rewrite /FHE_mult /linearize /FHE_mult_base /FHE_decrypt hornerD.
- rewrite linearize_prop.
- assert (size (PP * QQ) <= 3%N).
- {
- generalize (size_mul_leq PP QQ); intros.
- by rewrite H1 H2 in H3.
- }
- replace (q_reduce q (s ^+ 2)) with
- ((q_reduce q s) ^+ 2).
- - assert (PP * QQ = Poly [:: (PP * QQ)`_0; (PP * QQ)`_1; (PP * QQ)`_2]).
- {
- replace (PP * QQ) with (take_poly 3 (PP * QQ)) at 1.
- - unfold take_poly.
- rewrite -polyP.
- intros ?.
- rewrite coef_poly coef_Poly.
- case: leqP => ineq.
- + by rewrite nth_default.
- + by rewrite -(nth_mkseq 0 (fun i => (PP * QQ)`_i) ineq).
- - rewrite take_poly_id //.
- }
- rewrite {5}H4 !horner_Poly /= mul0r !add0r mulrDl addrC -!addrA.
- f_equal.
- + by rewrite expr2 mulrA.
- + by rewrite addrC addrA.
- - by rewrite rmorphXn.
-Qed.
- *)
- Admitted.
-
-Definition key_switch {q p : nat} (c0 c1 : {poly 'Z_q})
- (ks_key : {poly {poly 'Z_(p*q)}}) : {poly {poly 'Z_q}} :=
- c0%:P + map_poly (fun P => q_reduce q (div_round ((zlift c1) * (zlift P)) (p%:Z)))
- ks_key.
-
-Definition FHE_automorphism {q p : nat} (s e : {poly int})
- (a : {poly 'Z_(p*q)}) (P : {poly {poly 'Z_q}}) (j : nat) :=
- key_switch (comp_poly 'X^(2*j+1) P`_0)
- (comp_poly 'X^(2*j+1) P`_1)
- (key_switch_key s (comp_poly 'X^(2*j+1) s) e a).
-
-Lemma decrypt_FHE_automorphism_base {q p : nat} (s : {poly int}) (P : {poly 'Z_q})
- (PP : {poly {poly 'Z_q}}) (j : nat) :
- FHE_decrypt s PP = P ->
- comp_poly 'X^(2*j+1) P = FHE_decrypt (comp_poly 'X^(2*j+1) s)
- (map_poly (comp_poly 'X^(2*j+1)) PP).
-Proof.
- rewrite /FHE_decrypt.
- intros.
- replace (q_reduce q (s \Po 'X^(2 * j + 1))) with
- (comp_poly 'X^(2*j+1) (q_reduce q s)).
- - by rewrite horner_map /= H.
- - rewrite /q_reduce map_comp_poly /=.
- f_equal.
- by rewrite map_polyXn.
-Qed.
-
diff --git a/coq/FHE/encode.v b/coq/FHE/encode.v
deleted file mode 100644
index 06184a80..00000000
--- a/coq/FHE/encode.v
+++ /dev/null
@@ -1,5286 +0,0 @@
-Require Import Reals Lra Lia List Permutation.
-From mathcomp Require Import common ssreflect fintype bigop ssrnat matrix Rstruct complex seq fingroup.
-From mathcomp Require Import ssralg ssrfun.
-From mathcomp Require Import generic_quotient ring_quotient.
-From mathcomp Require Import poly mxpoly polydiv ssrint zmodp eqtype ssrbool div.
-
-Import ssralg.GRing.
-Require Import nth_root.
-
-Ltac coq_lra := lra.
-From mathcomp Require Import lra.
-
-Set Bullet Behavior "Strict Subproofs".
-
-
-Lemma INR0eq (n:nat) : (INR n == 0%R) = (n == 0%nat).
-Proof.
- apply/eqP.
- case: eqP.
- - by move=> ->.
- - by apply not_0_INR.
-Qed.
-
-Lemma natmul0eq (n : nat) : ((n%:R)%R == 0%R :> R) = (n == 0%nat).
-Proof.
- by rewrite -INRE INR0eq.
-Qed.
-
-Lemma modulo_modn (a b : nat) :
- Nat.modulo a b = modn a b.
-Proof.
- case: b => [| b].
- - by rewrite modn0.
- - rewrite modn_def /Nat.modulo.
- move: (Nat.divmod_spec a b 0 b (le_refl _)).
- case: Nat.divmod => q u [].
- rewrite Nat.mul_0_r Nat.sub_diag !Nat.add_0_r /= => eqq _.
- rewrite (_:(edivn a b.+1) = (q, (b - u)%coq_nat)) // -(@edivn_eq b.+1); try lia.
- rewrite eqq.
- f_equal; lia.
-Qed.
-
-
-Local Open Scope ring_scope.
-
-
-Section construct.
-
- Import Ring ComRing UnitRing Pdiv.IdomainDefs Pdiv.WeakIdomain.
-
- Lemma size_poly1P_w [R : Exports.ringType] (p : {poly R}) :
- size p == 1%nat -> {c | c != 0 & p = c%:P}.
- Proof.
- move/eqP=> pC.
- have def_p: p = (p`_0)%:P
- by rewrite -size1_polyC ?pC.
- by exists p`_0; rewrite // -polyC_eq0 -def_p -size_poly_eq0 pC.
- Qed.
-
- Lemma eqpP_w [R : idomainType] (m n : {poly R}) :
- (m %= n) ->
- {c12 | (c12.1 != 0) && (c12.2 != 0) & c12.1 *: m = c12.2 *: n}.
- Proof.
- case: (eqVneq m 0) => [-> /andP [/dvd0pP -> _] | m_nz].
- { by exists (1, 1); rewrite ?scaler0 // oner_eq0. }
- case: (eqVneq n 0) => [-> /andP [_ /dvd0pP ->] | n_nz /andP []].
- { by exists (1, 1); rewrite ?scaler0 // oner_eq0. }
- rewrite !dvdp_eq; set c1 := _ ^+ _; set c2 := _ ^+ _.
- set q1 := _ %/ _; set q2 := _ %/ _; move/eqP => Hq1 /eqP Hq2;
- have Hc1 : c1 != 0 by rewrite expf_eq0 lead_coef_eq0 negb_and m_nz orbT.
- have Hc2 : c2 != 0 by rewrite expf_eq0 lead_coef_eq0 negb_and n_nz orbT.
- have def_q12: q1 * q2 = (c1 * c2)%:P.
- apply: (mulIf m_nz); rewrite mulrAC mulrC -Hq1 -scalerAr -Hq2 scalerA.
- by rewrite -mul_polyC.
- have: q1 * q2 != 0 by rewrite def_q12 -size_poly_eq0 size_polyC mulf_neq0.
- rewrite mulf_eq0; case/norP=> nz_q1 nz_q2.
- have: size q2 <= 1.
- have:= size_mul nz_q1 nz_q2; rewrite def_q12 size_polyC mulf_neq0 //=.
- by rewrite polySpred // => ->; rewrite leq_addl.
- rewrite leq_eqVlt ltnS size_poly_leq0 (negPf nz_q2) orbF.
- case/size_poly1P_w=> c cn0 cqe; exists (c2, c); first by rewrite Hc2.
- by rewrite Hq2 -mul_polyC -cqe.
- Qed.
-
- Lemma Bezoutp_w [R : idomainType] (p q : {poly R}) : {u | u.1 * p + u.2 * q %= (gcdp p q)}.
- Proof.
- have [-> | pn0] := eqVneq p 0.
- - by rewrite gcd0p; exists (0, 1); rewrite mul0r mul1r add0r eqpxx.
- - have [-> | qn0] := eqVneq q 0.
- + by rewrite gcdp0; exists (1, 0); rewrite mul0r mul1r addr0 eqpxx.
- + pose e := egcdp p q; exists e; rewrite eqp_sym.
- by case: (egcdpP pn0 qn0).
- Qed.
-
- Lemma Bezout_coprimepP_w [R : idomainType] (p q : {poly R}) :
- coprimep p q ->
- {u | u.1 * p + u.2 * q %= 1}.
- Proof.
- rewrite -gcdp_eqp1.
- case: (Bezoutp_w p q) => [[u v] Puv].
- intros g1.
- by exists (u, v); apply: eqp_trans g1.
- Qed.
-
- Lemma Bezout_eq1_coprimepP_w [F : fieldType] (p q : {poly F}) :
- coprimep (R:=F) p q ->
- {u : poly_ringType F * poly_ringType F | u.1 * p + u.2 * q = 1}.
- Proof.
- move/ Bezout_coprimepP_w.
- case=> [[u v]] /=.
- case/eqpP_w=> [[c1 c2]] /andP /= [c1n0 c2n0] e.
- exists (c2^-1 *: (c1 *: u), c2^-1 *: (c1 *: v)); rewrite /= -!scalerAl.
- by rewrite -!scalerDr e scalerA mulVf // scale1r.
- Qed.
-
-End construct.
-
- Section Chinese.
-
-(* The chinese remainder theorem for polynomials overa a field *)
-Variables F : fieldType.
-
-
-Definition chinesep (m1 m2 r1 r2 : {poly F}) (co_m12: coprimep (R:=F) m1 m2) :=
- let u := sval (Bezout_eq1_coprimepP_w m1 m2 co_m12) in
- r1 * m2 * u.1 + r2 * m1 * u.2.
-
- Lemma chinesep_prop (m1 m2 r1 r2 : {poly F}) :
- coprimep (R:=F) m1 m2 ->
- {e : poly_ringType F |
- e %% m1 = r1 %% m1 /\ e %% m2 = r2 %% m2}.
- Proof.
- intros co_m12.
- destruct (Bezout_eq1_coprimepP_w m1 m2 co_m12).
- pose c := r2 * x.1 * m1 + r1 * x.2 * m2.
- exists c.
- split.
- - rewrite modpD modp_mull add0r.
- apply (f_equal (fun z => z %% m1)) in e.
- rewrite modpD modp_mull add0r in e.
- by rewrite -mulrA -modp_mul e modp_mul mulr1.
- - rewrite modpD modp_mull addr0.
- apply (f_equal (fun z => z %% m2)) in e.
- rewrite modpD modp_mull addr0 in e.
- by rewrite -mulrA -modp_mul e modp_mul mulr1.
-Qed.
-
- Lemma all_coprimep_prod (a : {poly F}) (l : seq {poly F}) :
- all (coprimep a) l ->
- coprimep a (\prod_(i <- l) i).
- Proof.
- intros.
- rewrite big_seq.
- apply big_rec.
- - apply coprimep1.
- - intros.
- move/allP/(_ _ H0): H.
- by rewrite coprimepMr H1 => ->.
- Qed.
-
- Lemma mod_mul_mod_r (a p q : {poly F}) :
- (a %% (p * q))%%q = a%%q.
- Proof.
- generalize (divp_eq a (p * q)); intros.
- apply (f_equal (fun z => z %% q)) in H.
- rewrite H.
- by rewrite modpD mulrA modp_mull add0r.
- Qed.
-
- Lemma mod_mul_mod_l (a p q : {poly F}) :
- (a %% (p * q))%%p = a%%p.
- Proof.
- rewrite mulrC.
- apply mod_mul_mod_r.
- Qed.
-
- Lemma prod_mod (a b : {poly F}) (l : seq {poly F}) :
- a %% (\prod_(q <- l) q) = b %% (\prod_(q <- l) q) ->
- forall p,
- p \in l -> a %% p = b %% p.
- Proof.
- induction l.
- - intros.
- by rewrite in_nil in H0.
- - intros.
- rewrite in_cons in H0.
- move /orP in H0.
- rewrite !big_cons in H.
- destruct H0.
- + rewrite (eqP H0).
- simpl in H.
- apply (f_equal (fun z => z %% a0)) in H.
- by rewrite !mod_mul_mod_l in H.
- + apply (f_equal (fun z => z %% (\prod_(j <- l) j))) in H.
- rewrite !mod_mul_mod_r in H.
- specialize (IHl H).
- by apply IHl.
- Qed.
-
-Lemma chinesep_list_prop (l : seq ({poly F} * {poly F})) :
- pairwise (coprimep (R := F)) (map snd l) ->
- { e : {poly F} |
- forall p,
- p \in l -> e %% p.2 = p.1 %% p.2}.
-Proof.
- induction l.
- - simpl.
- intros.
- exists 0.
- intros.
- by rewrite in_nil in H0.
- - intros.
- rewrite map_cons pairwise_cons in H.
- move /andP in H.
- destruct H.
- specialize (IHl H0).
- destruct IHl.
- assert (coprimep a.2 (\prod_(q <- l) q.2)).
- {
- generalize (all_coprimep_prod a.2 [seq i.2 | i <- l] H); intros.
- by rewrite big_map in H1.
- }
- destruct (chinesep_prop a.2 (\prod_(q <- l) q.2) a.1 x H1) as [? [??]].
- exists x0.
- intros.
- rewrite in_cons in H4.
- move /orP in H4.
- destruct H4.
- + by rewrite (eqP H4).
- + specialize (e p H4).
- rewrite -e.
- assert (\prod_(q <- l) q.2 = \prod_(q <- [seq i.2 | i <- l]) q).
- {
- by rewrite big_map.
- }
- rewrite H5 in H3.
- generalize (prod_mod x0 x [seq i.2 | i <- l] H3 p.2); intros.
- apply H6.
- by apply map_f.
- Qed.
-
-End Chinese.
-
-
-Definition odd_nth_roots (n : nat) :=
- \row_(j < 2^n) (nth_root (2 * j + 1) (2 ^ (S n))).
-
-Definition even_nth_roots (n : nat) :=
- \row_(j < 2^n) (nth_root (2 * j) (2 ^ (S n))).
-
-Definition nth_roots_half (n : nat) :=
- \row_(j < 2^n) (nth_root j (2 ^ (S n))).
-
-Lemma unity_root_nth_root (j n : nat) :
- n.+1.-unity_root (nth_root j n.+1).
-Proof.
- apply /unity_rootP.
- by rewrite nth_root_npow /RtoC /=.
-Qed.
-
-Lemma primitive_root_nth_root (n : nat) :
- n.+1.-primitive_root (nth_root 1 n.+1).
-Proof.
- intros.
- rewrite /primitive_root_of_unity.
- apply/andP.
- split; try lia.
- apply /forallP; intros.
- apply /eqP.
- apply /unity_rootP.
- case: (eqVneq x.+1 n.+1); intros.
- - by rewrite e nth_root_npow /RtoC /=.
- - rewrite exp_nth_root muln1.
- apply nth_root_not_1.
- rewrite Nat.mod_small; try lia.
- destruct x; simpl in *.
- lia.
-Qed.
-
-Lemma primitive_root_nth_root_coprime (j n : nat) :
- coprime j n.+1 ->
- n.+1.-primitive_root ((nth_root 1 n.+1) ^+ j).
-Proof.
- intros.
- rewrite prim_root_exp_coprime //.
- apply primitive_root_nth_root.
-Qed.
-
-Lemma primitive_root_nth_root_coprime_alt (j n : nat) :
- coprime j n.+1 ->
- n.+1.-primitive_root (nth_root j n.+1).
-Proof.
- intros.
- generalize (primitive_root_nth_root_coprime j n H); intros.
- by rewrite exp_nth_root muln1 in H0.
-Qed.
-
-Lemma pow2_S (j:nat) :
- { k : nat | (2^j)%nat == S k}.
-Proof.
- exists (2^j-1)%nat.
- induction j.
- - now simpl.
- - rewrite /= expnS (eqP IHj); lia.
-Defined.
-
-Lemma primitive_root_odd_nth_root (j n : nat) :
- (2^(n.+1)).-primitive_root ((nth_root 1 (2^n.+1)) ^+ (2 * j + 1)).
-Proof.
- destruct (pow2_S (S n)).
- rewrite (eqP i).
- apply primitive_root_nth_root_coprime.
- rewrite -(eqP i); clear i.
- assert (forall j0, coprime (2 * j0 + 1) 2).
- {
- intros.
- rewrite coprimen2.
- induction j0.
- * by rewrite muln0 add0n /=.
- * replace (2 * j0.+1 + 1)%N with (2 + (2 * j0 + 1))%N by lia.
- by rewrite oddD IHj0 /=.
- }
- induction n.
- - by rewrite expn1 H.
- - by rewrite expnS coprimeMr IHn // H.
-Qed.
-
-Lemma primitive_root_odd_nth_root_alt (j n : nat) :
- (2^(n.+1)).-primitive_root (nth_root (2 * j + 1) (2^n.+1)).
-Proof.
- generalize (primitive_root_odd_nth_root j n); intros.
- destruct (pow2_S (S n)).
- rewrite (eqP i).
- by rewrite (eqP i) exp_nth_root muln1 in H.
-Qed.
-
-Lemma pow2_nth_root_pow1 (n j : nat) :
- forall (e : nat),
- (nth_root j (2^n.+1)) ^+ e = 1 <->
- (e * j) mod (2^n.+1) = 0%N.
-Proof.
- intros.
- destruct (pow2_S n.+1).
- by rewrite (eqP i) exp_nth_root nth_root_1_iff.
-Qed.
-
-Lemma pow2_nth_root_pow_eq (n j : nat) :
- forall (e1 e2 : nat),
- (nth_root j (2^n.+1)) ^+ e1 =
- (nth_root j (2^n.+1)) ^+ e2 <->
- (e1 * j) = (e2 * j) %[mod (2^n.+1)].
-Proof.
- intros.
- rewrite -!modulo_modn.
- apply nth_root_pow_eq.
- lia.
-Qed.
-
-Lemma mul_INR n m :
- INR(n * m) = INR n * INR m.
-Proof.
- by rewrite mult_INR /mul /=.
-Qed.
-
-Lemma even_nth_root_half (j n : nat) :
- 0 < n ->
- nth_root (2 * j) (2 * n) = nth_root j n.
-Proof.
- intros.
- rewrite /nth_root.
- apply /eqP.
- rewrite eq_complex /=.
- apply /andP.
- assert (INR 2 != 0).
- {
- rewrite /zero/=.
- apply/eqP.
- coq_lra.
- }
- assert (INR 2 \is a unit).
- {
- by rewrite unitfE.
- }
- assert (INR n \is a unit).
- {
- rewrite unitfE INR0eq; lia.
- }
- assert (inv (INR 2) * (INR 2) = 1).
- {
- rewrite mulrC divff //.
- }
- split; apply /eqP; f_equal;
- rewrite -![2 * PI * _ * _]mulrA; f_equal;
- rewrite !mul_INR invrM // [INR 2 * _]mulrC -mulrA; f_equal;
- by rewrite mulrC -mulrA H3 mulr1.
-Qed.
-
-Lemma even_nth_root_half_pow (j n : nat) :
- nth_root (2 * j) (2 ^ (S n)) = nth_root j (2^n).
-Proof.
- destruct (pow2_S n).
- rewrite expnS (eqP i).
- apply even_nth_root_half; lia.
-Qed.
-
-Lemma lt_0_1 :
- is_true (0 < 1).
-Proof.
- easy.
-Qed.
-
-Lemma poly_size_eq0 {R:ringType} (p:{poly R}) :
- seq.size p == 0%nat = (p == 0).
-Proof.
- rewrite -size_poly_leq0.
- lia.
-Qed.
-
-Definition peval_mat {n} (roots : 'rV[R[i]]_n) : 'M[R[i]]_(n,n) :=
- \matrix_(i < n, j < n) (exp (roots 0 i) j).
-
-Definition conj_mat {n1 n2} (m : 'M[R[i]]_(n1,n2)) :=
- map_mx conjc m.
-
-Definition Vscale {n} (c : R[i]) (v : 'rV[R[i]]_n) :=
- c *: v.
-
-Definition vector_sum {n} (v : 'rV[R[i]]_n) :=
- \sum_(j < n) (v 0 j).
-
-Definition inner_prod {n} (v1 v2 : 'rV[R[i]]_n) :=
- (v1 *m (v2^T)) 0 0.
-
-Definition H_inner_prod {n} (v1 v2 : 'rV[R[i]]_n) :=
- inner_prod v1 (conj_mat v2).
-
-Lemma vector_sum_scale {n} (c : R[i]) (v : 'rV[R[i]]_n) :
- mul c (vector_sum v) = vector_sum (Vscale c v).
-Proof.
- unfold vector_sum.
- unfold Vscale.
- rewrite Theory.mulr_sumr.
- erewrite eq_big_seq; [reflexivity |].
- simpl.
- apply ssrbool.in1W; intros.
- now rewrite mxE.
-Qed.
-
-Definition ConstVector n (c : R[i]) : 'rV[R[i]]_n:= const_mx c.
-
-Definition RtoC (x : R):R[i] := Complex x 0.
-
-(* Coercion RtoC : R >-> complex. *)
-
-Lemma vector_sum_const {n} (c : R[i]) :
- vector_sum (ConstVector n c) = mul (n%:R) c.
-Proof.
- rewrite /vector_sum /ConstVector.
- (under eq_big_seq => - do (apply ssrbool.in1W => ?; rewrite mxE)).
- rewrite big_const_ord iter_addr_0 Theory.mulr_natl//.
-Qed.
-
-Lemma conj_transpose {n} (m : 'M[R[i]]_(n,n)) :
- conj_mat (m^T) = (conj_mat m)^T.
-Proof.
- now rewrite map_trmx.
-Qed.
-
-Lemma RtoCnat_eq n : RtoC (INR n) = n%:R.
-Proof.
- unfold RtoC.
- induction n.
- - now rewrite Theory.mulr0n.
- - rewrite Theory.mulrSr S_INR -IHn /add/= add0r//.
-Qed.
-
-(* testing notations *)
-Definition C0': R[i] := 0.
-Definition C1': R[i] := 1.
-Definition Cplus' (x y : R[i]) := x + y.
-Definition Cmult' (x y : R[i]) := x * y.
-Definition Cexp' (x : R[i]) (n : nat) := x ^+ n.
-Definition Cdiv' (x y : R[i]) := x / y.
-Definition Cinv' (x : R[i]) := x^-1.
-
-Lemma peval_row (n : nat) :
- forall n0,
- row n0 (peval_mat (odd_nth_roots (S n))) =
- \row_(j < 2^(S n)) (odd_nth_roots (S n) 0 n0) ^+ j.
-Proof.
- intros.
- unfold row.
- simpl.
- unfold peval_mat.
- apply eq_mx; intros ??.
- now rewrite mxE.
-Qed.
-
-Lemma pow_nth_root j n e :
- (nth_root j (S n)) ^+ e = nth_root (e * j) (S n).
-Proof.
- apply exp_nth_root.
-Qed.
-
-Lemma pow_nth_root' j n e :
- n != 0%nat ->
- (nth_root j n) ^+ e = nth_root (e * j) n.
-Proof.
- destruct n; [lia |]=>_.
- apply pow_nth_root.
-Qed.
-
-Lemma odd_nth_roots_pow (n : nat) (c : R[i]) :
- c ^+ (2 ^ (S n)) = -1 ->
- forall j,
- (c ^+ (2 * j + 1)) ^+ (2 ^ (S n)) = -1.
-Proof.
- intros.
- by rewrite exprAC -exprM exprM H exprD exprM expr1 expr2 mulrNN mulr1 expr1n mul1r.
-Qed.
-
-Lemma odd_nth_roots_powk (n : nat) (c : R[i]) :
- c ^+ (2 ^ (S n)) = -1 ->
- forall j k,
- (c ^+ ((2 * j + 1)^k)) ^+ (2 ^ (S n)) = -1.
-Proof.
- intros.
- rewrite -exprM mulnC exprM H.
- induction k.
- - by rewrite expn0 expr1.
- - by rewrite expnS mulnC exprM IHk exprD exprM expr1 expr2 mulrNN mulr1 expr1n mul1r.
-Qed.
-
-Lemma odd_nth_roots_conj (n : nat) (c : R[i]) :
- c ^+ (2 ^ n.+1) = -1 ->
- (conjc c) ^+ (2^n.+1) = -1.
-Proof.
- intros.
- by rewrite -rmorphXn /= H rmorphN1.
-Qed.
-
-Lemma decode_encode_on_diag (n : nat):
- let pmat := (peval_mat (odd_nth_roots (S n))) in
- forall n0,
- H_inner_prod (row n0 pmat) (row n0 pmat) = (2^S n)%:R.
-Proof.
- intros.
- unfold H_inner_prod, inner_prod, pmat.
- rewrite mxE.
- under eq_big_seq.
- - apply ssrbool.in1W; intros.
- rewrite peval_row /odd_nth_roots !mxE.
- destruct (pow2_S (S (S n))).
- rewrite (eqP i) pow_nth_root mult_conj_root.
- over.
- - rewrite big_const_ord iter_addr_0//.
-Qed.
-
-Lemma H_inner_prod_mat n (M : 'M[R[i]]_(n,n)) :
- forall i j,
- (M *m (conj_mat (M ^T))) i j =
- H_inner_prod (row i M) (row j M).
-Proof.
- rewrite /H_inner_prod /inner_prod => i j.
- rewrite !mxE //=.
- apply eq_big_seq => ??.
- rewrite !mxE//.
-Qed.
-
-Lemma telescope_mult_bigop_aux (c : R[i]) (n : nat) :
- (c - 1) * (\sum_(0 <= j < S n) (c ^+ j)) =
- \sum_(0 <= j < S n) ((c^+(S j)) - (c ^+ j)).
-Proof.
- rewrite big_distrr.
- simpl.
- apply eq_big_seq; intros ??.
- rewrite mulrBl.
- rewrite mul1r.
- f_equal.
- rewrite exprSr.
- now rewrite mulrC.
-Qed.
-
-Lemma telescope_mult_bigop (c : R[i]) (n : nat) :
- (c - 1) * (\sum_(0 <= j < S n) (c ^+ j)) =
- c ^+ (S n) - 1.
-Proof.
- rewrite telescope_mult_bigop_aux.
- rewrite telescope_sumr.
- + now rewrite expr0.
- + lia.
-Qed.
-
-Lemma telescope_div (c : R[i]) (n : nat) :
- c <> 1 ->
- \sum_(0 <= j < S n) (c ^+ j) =
- (c ^+ (S n) - 1) / (c - 1).
-Proof.
- intros.
- generalize (telescope_mult_bigop c n); intros.
- rewrite -H0 mulrC mulrA Cinv_l.
- - now rewrite mul1r.
- - unfold not.
- intros.
- clear H0.
- replace C0 with (zero C) in H1 by reflexivity.
- apply (f_equal (fun cc => add cc 1)) in H1.
- by rewrite add0r -addrA (addrC _ 1) addrN addr0 in H1.
-Qed.
-
-Lemma telescope_pow_0_nat (c : R[i]) (n : nat) :
- c <> 1 ->
- c ^+ (S n) = 1 ->
- \sum_(0 <= j < S n) (c ^+ j) = C0.
-Proof.
- intros.
- rewrite telescope_div; trivial.
- by rewrite H0 addrN mul0r.
-Qed.
-
-Lemma telescope_pow_0_ord (c : R[i]) (n : nat) :
- c <> 1 ->
- c ^+ (S n) = 1 ->
- \sum_(j < S n) (c ^+ j) = C0.
-Proof.
- intros.
- rewrite <- (telescope_pow_0_nat c n); trivial.
- by rewrite /= big_mkord.
-Qed.
-
-Lemma add_conj (c1 c2 : R[i]) :
- (conjc c1) + (conjc c2) = conjc (c1 + c2).
-Proof.
- by rewrite rmorphD.
-Qed.
-
-Lemma mul_conj (c1 c2 : R[i]) :
- (conjc c1) * (conjc c2) = conjc (c1 * c2).
-Proof.
- by rewrite rmorphM.
-Qed.
-
-Lemma exp_conj (c : R[i]) n :
- conjc (c ^+ n) = (conjc c)^+n.
-Proof.
- by rewrite rmorphXn.
-Qed.
-
-Lemma decode_encode_off_diag (n : nat):
- let pmat := (peval_mat (odd_nth_roots (S n))) in
- forall n1 n2,
- n1 <> n2 ->
- H_inner_prod (row n1 pmat) (row n2 pmat) = C0.
-Proof.
- intros.
- unfold H_inner_prod, inner_prod.
- unfold pmat, peval_mat.
- rewrite mxE.
- simpl.
- destruct (pow2_S (S n)).
- unfold odd_nth_roots.
- generalize (telescope_pow_0_ord ((nth_root (2*n1+1) (2^S(S n))) *
- (conjc (nth_root (2*n2+1) (2^S(S n))))) x); intros.
- rewrite <- H0.
- - rewrite <- (eqP i).
- erewrite eq_big_seq; [reflexivity |].
- simpl.
- apply ssrbool.in1W; intros.
- rewrite !mxE exprMn_comm.
- + rewrite exp_conj//.
- + rewrite /comm mulrC//.
- - unfold not; intros.
- destruct (pow2_S (S (S n))).
- rewrite (eqP i0) nth_root_conj_alt nth_root_mul in H1.
- apply nth_root_1_iff in H1.
- rewrite -(eqP i0) in H1.
- clear H0 i i0 pmat x x0.
- destruct n1 as [x xlt]; destruct n2 as [y ylt]; simpl in *.
- assert (neq:x <> y).
- {
- intros HH.
- apply H; subst.
- f_equal; apply eqtype.bool_irrelevance.
- }
- clear H.
- rewrite !modulo_modn in H1.
- apply (f_equal ssrint.Posz) in H1.
- revert H1.
- rewrite -intdiv.modz_nat ssrint.PoszD -ssrint.subzn.
- + rewrite -intdiv.modz_nat.
- rewrite -intdiv.modzDm.
- rewrite !addn1 intdiv.modzDl intdiv.modzNm.
- rewrite !intdiv.modzDm expnSr.
- destruct (@leP x y).
- * rewrite -intdiv.modzDl intdiv.modz_small/=; lia.
- * rewrite intdiv.modz_small/=; lia.
- + lia.
- - destruct (pow2_S (S (S n))).
- rewrite (eqP i0) nth_root_conj_alt nth_root_mul exp_nth_root.
- apply nth_root_1_iff.
- rewrite -(eqP i0) -(eqP i).
- clear H0 i i0 pmat x x0.
- destruct n1 as [x xlt]; destruct n2 as [y ylt]; simpl in *.
- assert (neq:x <> y).
- {
- intros HH.
- apply H; subst.
- f_equal; apply eqtype.bool_irrelevance.
- }
- clear H.
- rewrite !modulo_modn.
- replace (expn 2 (S (S n))) with (expn 2 (S n) * 2)%nat by (rewrite (expnS _ (S n)); lia).
- rewrite -div.muln_modr -div.modnDm.
- replace (2 * x)%nat with (x * 2)%nat by lia.
- rewrite div.modnMDl.
- replace (div.modn (2 ^ n.+1 * 2 - div.modn (2 * y + 1) (2 ^ n.+1 * 2)) 2) with
- (div.modn 1 2).
- + rewrite div.modnDm.
- replace (1 + 1)%nat with 2%nat by lia.
- rewrite div.modnn; lia.
- + replace ( div.modn (2 * y + 1) (2 ^ n.+1 * 2)) with (2 * y + 1)%nat.
- * rewrite div.modnB; try lia.
- * rewrite div.modn_small; lia.
- Qed.
-
-Lemma decode_encode_scalar_mx (n : nat):
- let pmat := (peval_mat (odd_nth_roots (S n))) in
- pmat *m (conj_mat (pmat^T)) = scalar_mx (2^S n)%:R.
-Proof.
- intros.
- apply matrixP; intros ??.
- do 2 rewrite mxE.
- destruct (@eqtype.eqP _ x y); intros.
- - rewrite e.
- simpl.
- generalize (decode_encode_on_diag n); intros.
- rewrite eqtype.eq_refl mulr1n.
- unfold pmat in *.
- specialize (H y).
- unfold H_inner_prod, inner_prod in H.
- rewrite mxE in H.
- rewrite <- H.
- simpl.
- erewrite eq_big_seq; [reflexivity |].
- apply ssrbool.in1W; intros.
- now repeat rewrite mxE.
- - simpl.
- generalize (decode_encode_off_diag n x y n0); intros.
- unfold H_inner_prod, inner_prod in H.
- rewrite mxE/=/row in H.
- (* I am sure there is a better way to do this *)
- repeat rewrite /eqtype.eq_op/=.
- destruct (@eqnP x y); simpl in *.
- + elim n0.
- now apply ord_inj.
- + replace ((2 ^ n.+1)%:R *+ 0) with C0 by reflexivity.
- rewrite <- H.
- erewrite eq_big_seq; [reflexivity |].
- apply ssrbool.in1W; intros.
- now repeat rewrite mxE.
-Qed.
-
-Lemma decode_mat_encode_mat (n : nat) (cl : 'cV[R[i]]_(2^(S n))) :
- let pmat := peval_mat (odd_nth_roots (S n)) in
- let encmat := conj_mat (pmat^T) in
- pmat *m (encmat *m cl) = (2^S n)%:R *: cl.
-Proof.
- simpl.
- rewrite mulmxA.
- generalize (decode_encode_scalar_mx n); intros.
- rewrite H.
- now rewrite mul_scalar_mx.
-Qed.
-
-(* shows evaluation can be done by modified FFT of half size*)
-Lemma peval_mat_prod (n : nat) :
- peval_mat (odd_nth_roots (S n)) =
- peval_mat (even_nth_roots (S n)) *m diag_mx (nth_roots_half (S n)).
-Proof.
- apply matrixP; intros ??.
- unfold nth_roots_half, even_nth_roots, peval_mat.
- rewrite mul_mx_diag.
- repeat rewrite mxE.
- destruct (pow2_S (S (S n))).
- rewrite (eqP i) !pow_nth_root nth_root_mul.
- f_equal.
- lia.
-Qed.
-
-(* shows enconding can be done by modified IFFT of half size*)
-Lemma encode_mat_prod (n : nat) :
- let pmat := peval_mat (odd_nth_roots (S n)) in
- let encmat := (conj_mat (pmat^T)) in
- encmat =
- diag_mx (map_mx conjc (nth_roots_half (S n)))
- *m
- peval_mat (map_mx conjc (even_nth_roots (S n))).
-Proof.
- apply matrixP; intros ??.
- unfold nth_roots_half, conj_mat, peval_mat, even_nth_roots.
- rewrite mul_diag_mx.
- repeat rewrite mxE.
- destruct (pow2_S (S (S n))).
- rewrite (eqP i) pow_nth_root -exp_conj mul_conj.
- f_equal.
- rewrite pow_nth_root nth_root_mul.
- f_equal.
- lia.
-Qed.
-
-Definition vector_rev {n} {T} (v : 'rV[T]_n) :=
- \row_(i < n) v 0 (rev_ord i).
-
-Definition vector_rev_conj {n} (v : 'rV[R[i]]_n) :=
- forall i,
- v 0 i = conjc (v 0 (rev_ord i)).
-
-Lemma vector_rev_conj_plus {n} (v1 v2 : 'rV[R[i]]_n) :
- vector_rev_conj v1 ->
- vector_rev_conj v2 ->
- vector_rev_conj (map2_mx (fun (c1 c2 : R[i]) => add c1 c2) v1 v2).
-Proof.
- unfold vector_rev_conj; intros.
- do 2 rewrite mxE.
- rewrite H.
- rewrite H0.
- now rewrite -rmorphD.
-Qed.
-
-Lemma vector_rev_conj_mult {n} (v1 v2 : 'rV[R[i]]_n) :
- vector_rev_conj v1 ->
- vector_rev_conj v2 ->
- vector_rev_conj (map2_mx (fun (c1 c2 : R[i]) => mul c1 c2) v1 v2).
-Proof.
- unfold vector_rev_conj; intros.
- do 2 rewrite mxE.
- rewrite H; rewrite H0.
- now rewrite -rmorphM.
-Qed.
-
-Lemma vector_rev_conj_scale {n} (r : R) (v : 'rV[R[i]]_n) :
- vector_rev_conj v ->
- vector_rev_conj (Vscale (RtoC r) v).
-Proof.
- unfold vector_rev_conj; intros.
- unfold Vscale.
- rewrite mxE.
- rewrite H.
- rewrite mxE.
- rewrite <- mul_conj.
- f_equal.
- unfold conjc, RtoC.
- apply f_equal.
- lra.
-Qed.
-
-Lemma vector_rev_conj_const_R n (r : R) :
- vector_rev_conj (ConstVector n (RtoC r)).
-Proof.
- unfold vector_rev_conj, ConstVector, RtoC; intros.
- do 2 rewrite mxE.
- unfold conjc.
- apply f_equal.
- lra.
-Qed.
-
-Lemma vector_rev_conj_conj {n} (v : 'rV[R[i]]_n) :
- vector_rev_conj v ->
- vector_rev_conj (map_mx conjc v).
-Proof.
- unfold vector_rev_conj; intros.
- do 2 rewrite mxE.
- now rewrite H.
-Qed.
-
-Lemma vector_rev_conj_exp {n} i (v : 'rV[R[i]]_n) :
- vector_rev_conj v ->
- vector_rev_conj (map_mx (fun c => exp c i) v).
-Proof.
- unfold vector_rev_conj; intros.
- do 2 rewrite mxE.
- rewrite H.
- now rewrite exp_conj.
-Qed.
-
-Lemma Cconj_im_0 (c : C) :
- conjc c = c -> Im c = 0%R.
-Proof.
- intros.
- destruct c.
- move /eqP in H.
- rewrite /conjc eq_complex /= in H.
- move /andP in H.
- destruct H.
- simpl.
- lra.
-Qed.
-
-Lemma vector_rev_sum_rev {n} (v : 'rV[R[i]]_n) :
- vector_rev_conj v ->
- forall i,
- Im ((v + vector_rev v) 0 i) = 0.
-Proof.
- intros.
- rewrite /vector_rev !mxE H rev_ordK.
- apply Cconj_im_0.
- rewrite rmorphD /= conjcK addrC//.
-Qed.
-
-Lemma vector_rev_reflect {n} (v : 'rV[R[i]]_n) i :
- vector_rev v 0 i = v 0 (rev_ord i).
-Proof.
- rewrite mxE//.
-Qed.
-
-Lemma vector_sum_rev {n} (v : 'rV[R[i]]_n) :
- vector_sum v = vector_sum (vector_rev v).
-Proof.
- unfold vector_sum, vector_rev.
- rewrite (reindex_inj rev_ord_inj)/=.
- apply eq_big_seq, ssrbool.in1W => x.
- rewrite mxE//.
-Qed.
-
-Lemma vector_sum_add {n} (a b : 'rV[R[i]]_n) :
- vector_sum (a + b) = vector_sum a + vector_sum b.
-Proof.
- unfold vector_sum.
- cut (\sum_(j < n) (a 0 j + b 0 j) = \sum_(j < n) a 0 j + \sum_(j < n) b 0 j).
- {
- intros HH.
- rewrite -HH/=.
- apply eq_big_seq, ssrbool.in1W => x.
- rewrite mxE//.
- }
- rewrite big_split //.
-Qed.
-
-Lemma Im_add (a b:R[i]) : Im (a + b) = Im a + Im b.
-Proof.
- now destruct a; destruct b; simpl.
-Qed.
-
-Lemma vector_sum_reals {n} (v : 'rV[R[i]]_n) :
- (forall i, Im (v 0 i) = 0) ->
- Im (vector_sum v) = 0.
-Proof.
- unfold vector_sum.
- apply big_rec; simpl; trivial.
- intros.
- rewrite Im_add H1 H0// addr0//.
-Qed.
-
-Lemma vector_rev_conj_sum {n} (v : 'rV[R[i]]_n) :
- vector_rev_conj v ->
- Im (vector_sum v) = 0%R.
-Proof.
- intros.
- cut (Im (vector_sum v + vector_sum (vector_rev v)) = 0).
- {
- rewrite -vector_sum_rev.
- destruct (vector_sum v); simpl.
- rewrite /add /zero/=.
- coq_lra.
- }
- rewrite -vector_sum_add vector_sum_reals//.
- now apply vector_rev_sum_rev.
-Qed.
-
-Lemma inner_product_as_sum {n} (v1 v2 : 'rV[R[i]]_n) :
- inner_prod v1 v2 = vector_sum (map2_mx (fun a b => a * b) v1 v2).
-Proof.
- rewrite /inner_prod /mulmx/= mxE /vector_sum/=.
- apply eq_big_seq, ssrbool.in1W => x.
- rewrite /map2_mx /trmx !mxE//.
-Qed.
-
-Lemma vector_rev_conj_inner {n} (v1 v2 : 'rV[R[i]]_n) :
- vector_rev_conj v1 ->
- vector_rev_conj v2 ->
- Im (inner_prod v1 v2) = 0.
-Proof.
- intros.
- rewrite inner_product_as_sum vector_rev_conj_sum//.
- now apply vector_rev_conj_mult.
-Qed.
-
-Lemma vector_rev_conj_odd_nth_roots (n : nat) :
- vector_rev_conj (odd_nth_roots (S n)).
-Proof.
- unfold vector_rev_conj, odd_nth_roots.
- intros.
- do 2 rewrite mxE.
- destruct (pow2_S (S (S n))).
- rewrite (eqP i0).
- rewrite nth_root_conj_alt.
- f_equal.
- rewrite -(eqP i0).
- rewrite (expnS _ (S n)).
- unfold rev_ord; simpl.
- rewrite Nat.mod_small; try lia.
- destruct i.
- simpl.
- lia.
-Qed.
-
-Lemma mv_rev_conj_real (n1 n2 : nat) (mat : 'M[R[i]]_(n1,n2)) (cl : 'cV[R[i]]_n2) :
- vector_rev_conj (cl^T) ->
- (forall i, vector_rev_conj (row i mat)) ->
- forall i,
- Im ((mat *m cl) i 0) = 0.
-Proof.
- intros.
- replace ((mat *m cl) i 0) with (((row i mat) *m cl) 0 0).
- - generalize (vector_rev_conj_inner (row i mat) (cl^T)); intros HH.
- unfold inner_prod in HH.
- rewrite trmxK in HH.
- now apply HH.
- - repeat rewrite mxE.
- apply eq_big_seq; intros ??.
- now repeat rewrite mxE.
-Qed.
-
-Lemma encode_mat_pow_odd_roots (n:nat) :
- let pmat := (peval_mat (odd_nth_roots (S n))) in
- forall i,
- row i pmat^T = (map_mx (fun c => c ^+ i) (odd_nth_roots (S n))).
-Proof.
- intros.
- unfold odd_nth_roots, pmat, peval_mat.
- apply matrixP; intros ??.
- now repeat rewrite mxE.
-Qed.
-
-Lemma mat_encode_real {n} (cl : 'cV[R[i]]_(2^(S n))) :
- let pmat := (peval_mat (odd_nth_roots (S n))) in
- let encmat := (conj_mat (pmat^T)) in
- vector_rev_conj (cl^T) ->
- forall i,
- Im ((encmat *m cl) i 0) = 0.
-Proof.
- intros.
- apply mv_rev_conj_real; trivial.
- intros.
- unfold encmat, conj_mat.
- rewrite <- map_row.
- apply vector_rev_conj_conj; simpl.
- rewrite encode_mat_pow_odd_roots.
- apply vector_rev_conj_exp.
- apply vector_rev_conj_odd_nth_roots.
-Qed.
-
-Lemma Re_Im_0 (c : R[i]) :
- Im c = 0 <-> c = RtoC (Re c).
-Proof.
- destruct c.
- unfold RtoC; simpl.
- split; intros.
- - now rewrite H.
- - inversion H.
- unfold zero; simpl; coq_lra.
-Qed.
-
-Lemma mat_encode_real_alt {n} (cl : 'cV[R[i]]_(2^(S n))) :
- let pmat := (peval_mat (odd_nth_roots (S n))) in
- let encmat := (conj_mat (pmat^T)) in
- vector_rev_conj (cl^T) ->
- encmat *m cl = map_mx (fun c => RtoC (Re c)) (encmat *m cl).
-Proof.
- intros.
- apply matrixP => x y.
- generalize (mat_encode_real cl H x) => HH.
- apply Re_Im_0 in HH.
- by rewrite ord1 {}HH !mxE.
-Qed.
-
-Definition vector_rev_col {n} {T} (v : 'cV[T]_n) :=
- \col_(i < n) v (rev_ord i) 0.
-
-Program Definition vector_reflect_conj {n} (cl : 'cV[R[i]]_(2^n)) : 'cV[R[i]]_(2^(S n)) :=
- col_mx cl (conj_mat (vector_rev_col cl)).
-Next Obligation.
- intros.
- rewrite expnS.
- lia.
-Qed.
-
-Lemma vector_reflect_conj_rev_conj {n} (cl : 'cV[R[i]]_(2^n)) :
- vector_rev_conj (vector_reflect_conj cl)^T.
-Proof.
- unfold vector_rev_conj, vector_reflect_conj.
- intros.
- unfold eq_rect.
- destruct (vector_reflect_conj_obligation_1 n).
- unfold conj_mat, vector_rev_col.
- repeat rewrite mxE.
- destruct (splitP i); destruct (splitP (rev_ord i)); unfold rev_ord in *; simpl in *.
- - destruct i; destruct j; destruct j0; simpl in *; lia.
- - rewrite !mxE/= conjcK.
- f_equal.
- destruct j.
- cut (m = 2 ^ n - k.+1)%nat.
- {
- intros; subst.
- f_equal; apply eqtype.bool_irrelevance.
- }
- destruct i; simpl in *; subst; lia.
- - rewrite !mxE/=.
- do 2 f_equal.
- destruct j.
- cut (2 ^ n - k.+1 = m)%nat.
- {
- intros; subst.
- f_equal; apply eqtype.bool_irrelevance.
- }
- destruct i; simpl in *; subst; lia.
- - destruct i; destruct k; destruct k0; simpl in *; lia.
-Qed.
-
-Definition CKKS_poly_encode {n} (cl : 'cV[R[i]]_(2^n)) : 'cV[R]_(2^(S n)) :=
- let pmat := (peval_mat (odd_nth_roots (S n))) in
- let encmat := (conj_mat (pmat^T)) in
- (inv (2 ^+ S n)) *:
- (map_mx (fun c => Re c) (encmat *m (vector_reflect_conj cl))).
-
-Definition int_to_zmodp (i : int) (p : nat) : 'Z_p := i %:~R.
-
-Definition col_to_poly {n} (cl : 'cV[R]_n) := rVpoly (cl^T).
-Definition col_to_poly2 {n} (cl : 'cV[int]_n) := rVpoly (cl^T).
-
-
-Definition red_poly (pol : {poly int}) p' :=
- map_poly (fun c => int_to_zmodp c p') pol.
-
-From mathcomp Require Import order.
-
-(* 0 <= rand < 1 *)
-Definition ran_round (x rand : R) :=
- let hi := up x in
- if (Order.lt ((IZR hi) - x) rand)%R then hi else (Zminus hi 1).
-
-Definition nearest_round (x : R) := ran_round x (1/2).
-
-Definition mx_round {n m} (mat : 'M[R]_(n,m)) : 'M[int]_(n,m) :=
- map_mx (fun r => ssrZ.int_of_Z (nearest_round r)) mat.
-
-Definition CKKS_poly_encode_Z {n} (cl : 'cV[R[i]]_(2^n)) : 'cV[int]_(2^(S n)) :=
- mx_round (CKKS_poly_encode cl).
-
-Definition vector_proj_coef {n} (v1 v2 : 'rV[R[i]]_n) :=
- (H_inner_prod v1 v2) / (H_inner_prod v2 v2).
-
-(* this is multiplication for vectors mod monic p *)
-Definition rv_mul_mod {n} (a b : 'rV[R]_n) (p : {poly R}) : 'rV[R]_n :=
- poly_rV (rVpoly(a) * rVpoly b %% p).
-
-(* poly rem x^n+1 *)
-Definition rv_mul_mod_xn_1 {n} (a b : 'rV[R]_n) (n : nat) : 'rV[R]_n :=
- let prod := (rVpoly a * rVpoly b) in
- poly_rV (take_poly n prod - drop_poly n prod).
-
-Lemma size_Xn_addC [R : ringType] (n :nat) (b : R) :
- seq.size ('X^n.+1 + b%:P) = n.+2.
-Proof.
- rewrite size_addl size_polyXn// (leq_ltn_trans (size_polyC_leq1 b))//.
-Qed.
-
-Lemma lead_coef_xn [R : unitRingType] (n : nat) (c : R) :
- lead_coef ('X^n.+1 + c%:P) \is a unit.
-Proof.
- rewrite lead_coefDl.
- - rewrite lead_coefXn unitr1 //.
- - rewrite size_polyXn.
- generalize (size_polyC_leq1 c); lia.
-Qed.
-
-Lemma poly_rem_xn [R : idomainType] (n : nat) (c : R) (a : {poly R}) :
- let p := 'X^n.+1 + polyC c in
- a %% p = take_poly n.+1 a + (-c)*: (drop_poly n.+1 a %% p).
-Proof.
- simpl.
- have H := lead_coef_xn n c.
- generalize (size_Xn_addC n c); intros.
- rewrite -{1}(poly_take_drop n.+1 a).
- rewrite Pdiv.IdomainUnit.modpD; trivial.
- f_equal.
- - rewrite modp_small; trivial.
- rewrite H0.
- generalize (size_take_poly n.+1 a); intros.
- apply H1.
- - rewrite -Pdiv.IdomainUnit.modp_mul; trivial.
- assert ('X^n.+1 %% ('X^n.+1 + c%:P) = -c%:P).
- {
- assert ('X^n.+1 = 1 * ('X^n.+1 + c%:P) -c%:P).
- {
- rewrite mul1r -addrA subrr addr0//.
- }
- rewrite -(Pdiv.IdomainUnit.modpP H H1)//.
- rewrite size_Xn_addC size_opp ltnS.
- rewrite (leq_trans (size_polyC_leq1 c))//.
- }
- rewrite H1 mulrC -Pdiv.IdomainUnit.modpZl // -mul_polyC polyCN //.
-Qed.
-
-From mathcomp Require Import polydiv.
-
-Lemma Xn_add_c_monic [R : ringType] n (c: R) :
- 'X^n.+1 + c%:P \is monic.
-Proof.
- rewrite monicE lead_coefDl.
- - rewrite lead_coefXn //.
- - rewrite size_polyXn.
- generalize (size_polyC_leq1 c); lia.
- Qed.
-
-Lemma poly_rem_xn_alt [R : comRingType] (n : nat) (c : R) (a : {poly R}) :
- let p := 'X^n.+1 + polyC c in
- Pdiv.CommonRing.rmodp a p =
- take_poly n.+1 a +
- (Pdiv.CommonRing.rmodp ((-c)*:(drop_poly n.+1 a)) p).
-Proof.
- simpl.
- have H := Xn_add_c_monic n c.
- generalize (size_Xn_addC n c); intros.
- rewrite -{1}(poly_take_drop n.+1 a).
- rewrite Pdiv.RingMonic.rmodpD; trivial.
- f_equal.
- - rewrite Pdiv.CommonRing.rmodp_small; trivial.
- rewrite H0.
- generalize (size_take_poly n.+1 a); intros.
- apply H1.
- - rewrite -Pdiv.RingMonic.rmodp_mulmr; trivial.
- assert (Pdiv.CommonRing.rmodp ('X^n.+1) ('X^n.+1 + c%:P) = -c%:P).
- {
- assert ('X^n.+1 = 1 * ('X^n.+1 + c%:P) -c%:P).
- {
- rewrite mul1r -addrA subrr addr0//.
- }
-
- rewrite {1}H1.
- rewrite Pdiv.RingMonic.rmodp_addl_mul_small; trivial.
- rewrite size_Xn_addC size_opp ltnS.
- rewrite (leq_trans (size_polyC_leq1 c))//.
- }
- rewrite H1 mulrC; trivial.
- rewrite -mul_polyC polyCN //.
-Qed.
-
-Require Import Recdef.
-Function poly_rem_xn_1 [R : ringType] (n : nat) (a : {poly R}) {measure seq.size a} : {poly R} :=
- let a1 := take_poly n.+1 a in
- let a2 := drop_poly n.+1 a in
- if a2 == 0 then a1 else
- a1 - poly_rem_xn_1 n a2.
-Proof.
- intros.
- rewrite size_drop_poly.
- enough (seq.size a <> 0)%nat by lia.
- intros eqq.
- rewrite drop_poly_eq0 in teq.
- - rewrite eq_refl// in teq.
- - rewrite eqq//.
-Defined.
-
-Arguments poly_rem_xn_1 [R] n a : rename.
-
-Lemma poly_rem_xn_id [R : ringType] n (a:{poly R}) : seq.size a <= n.+1 ->
- poly_rem_xn_1 n a = a.
-Proof.
- functional induction poly_rem_xn_1 n a => slt.
- - rewrite take_poly_id//.
- - rewrite drop_poly_eq0// eqxx// in y.
-Qed.
-
-Lemma poly_rem_xn_1_le [R : ringType] n (a:{poly R}) : is_true (seq.size (poly_rem_xn_1 n a) <= n.+1).
-Proof.
- functional induction poly_rem_xn_1 n a.
- - rewrite size_take_poly//.
- - rewrite (leq_trans (size_add _ _)) // size_opp geq_max IHp size_take_poly//.
- Qed.
-
-Lemma poly_size_0 {R : ringType} :
- seq.size (zero (poly_zmodType R)) = 0%nat.
-Proof.
- rewrite /zero /= size_polyC eqxx//.
-Qed.
-
-Lemma drop_poly_eq0_iff {R : ringType} (m : nat) (p : {poly R}) :
- seq.size p <= m <-> drop_poly m p = 0.
-Proof.
- split; intros.
- - now apply drop_poly_eq0.
- - generalize (size_drop_poly m p); intros.
- rewrite H poly_size_0 in H0.
- lia.
- Qed.
-
-Lemma poly_rem_xn_1_pmod [R : idomainType] n (a : {poly R}) :
- poly_rem_xn_1 n a = a %% ('X^n.+1 + 1%:P).
-Proof.
- functional induction poly_rem_xn_1 n a.
- - move => /eqP in e.
- apply drop_poly_eq0_iff in e.
- rewrite take_poly_id //.
- rewrite modp_small; trivial.
- rewrite size_Xn_addC.
- apply e.
- - rewrite poly_rem_xn IHp scaleN1r//.
-Qed.
-
-Definition equiv_xn_1 [R : ringType] (n : nat): rel {poly R} :=
- fun p => fun q => poly_rem_xn_1 n (p - q) == 0.
-
-Definition ideal_xn_1_pred [R : ringType] (n : nat) : pred (poly_zmodType R) :=
- fun p => poly_rem_xn_1 n p == 0.
-
-Lemma poly_rem_xn_1_1 [R : ringType] n :
- poly_rem_xn_1 (R:=R) n 1 = 1.
-Proof.
- apply poly_rem_xn_id.
- rewrite size_poly1//.
-Qed.
-
-Lemma poly_rem_xn_0_0 [R : ringType] n :
- poly_rem_xn_1 (R:=R) n 0 = 0.
-Proof.
- apply poly_rem_xn_id.
- rewrite size_poly0//.
-Qed.
-
-Lemma rmodp_monic_scale [R : ringType] (c : R) (p d : {poly R}) :
- d \is monic ->
- Pdiv.CommonRing.rmodp (c *: p) d = c *: (Pdiv.CommonRing.rmodp p d).
-Proof.
- move=> monic.
- have sz: seq.size (c *: (Pdiv.CommonRing.rmodp (R:=R) p d)) < seq.size d
- by rewrite (leq_ltn_trans (size_scale_leq c (Pdiv.CommonRing.rmodp (R:=R) p d)))//
- Pdiv.CommonRing.ltn_rmodpN0// monic_neq0//.
-
- rewrite -(Pdiv.RingMonic.rmodp_addl_mul_small monic (c *: Pdiv.CommonRing.rdivp (R:=R) p d) sz).
- by rewrite {1}(Pdiv.RingMonic.rdivp_eq monic p) scalerDr scalerAl.
-Qed.
-
-Lemma rmodp_monic_opp [R : ringType] (p d : {poly R}) :
- d \is monic ->
- Pdiv.CommonRing.rmodp (- p) d = -(Pdiv.CommonRing.rmodp p d).
-Proof.
- rewrite -!scaleN1r.
- by apply rmodp_monic_scale.
-Qed.
-
-Lemma poly_rem_xn_1_pmod_alt [R : comRingType] n (a : {poly R}) :
- poly_rem_xn_1 n a = Pdiv.CommonRing.rmodp a ('X^n.+1 + 1%:P).
-Proof.
- functional induction poly_rem_xn_1 n a.
- - move => /eqP in e.
- apply drop_poly_eq0_iff in e.
- rewrite take_poly_id //.
- rewrite Pdiv.Ring.rmodp_small; trivial.
- rewrite size_Xn_addC.
- apply e.
- - rewrite poly_rem_xn_alt.
- f_equal.
- rewrite IHp.
- rewrite -rmodp_monic_opp.
- + f_equal.
- rewrite scaleN1r //.
- + apply Xn_add_c_monic.
-Qed.
-
-Lemma poly_rem_xn_1_eq0_mul [R : comRingType] n (a b: {poly R}) :
- poly_rem_xn_1 n b = 0 ->
- poly_rem_xn_1 n (a * b) = 0.
-Proof.
- do 2 rewrite poly_rem_xn_1_pmod_alt; intros.
- rewrite -Pdiv.RingMonic.rmodp_mulmr.
- - rewrite H mulr0.
- apply Pdiv.Ring.rmod0p.
- - apply Xn_add_c_monic.
- Qed.
-
-Lemma poly_rem_xn_1_eq0_add [R : comRingType] n (a b: {poly R}) :
- poly_rem_xn_1 n a = 0 ->
- poly_rem_xn_1 n b = 0 ->
- poly_rem_xn_1 n (a + b) = 0.
-Proof.
- do 3 rewrite poly_rem_xn_1_pmod_alt; intros.
- rewrite Pdiv.RingMonic.rmodpD.
- - rewrite H H0 addr0 //.
- - apply Xn_add_c_monic.
- Qed.
-
-Lemma poly_rem_xn_1_eq0_opp [R : comRingType] n (a: {poly R}) :
- poly_rem_xn_1 n a = 0 ->
- poly_rem_xn_1 n (- a) = 0.
-Proof.
- replace (-a) with (-1 * a).
- - apply poly_rem_xn_1_eq0_mul.
- - apply Theory.mulN1r.
-Qed.
-
-Lemma ideal_xn_1_pred_proper [R : comRingType] n : proper_ideal (ideal_xn_1_pred (R:=R) n).
-Proof.
- rewrite /proper_ideal /in_mem /mem/= /ideal_xn_1_pred.
- split.
- - rewrite poly_rem_xn_1_1.
- apply oner_neq0.
- - rewrite /prop_in1/= /in_mem /= => a b /eqP /poly_rem_xn_1_eq0_mul->//.
-Qed.
-
-Lemma ideal_xn_1_pred_zmod [R : comRingType] n : zmodPred (ideal_xn_1_pred (R :=R) n).
-Proof.
- constructor.
- - constructor; [constructor|].
- constructor.
- + rewrite /in_mem //= /ideal_xn_1_pred poly_rem_xn_0_0 //.
- + rewrite /in_mem //= /prop_in2 /ideal_xn_1_pred => a b.
- rewrite /in_mem /mem /= => /eqP-eqq1 /eqP-eqq2.
- rewrite poly_rem_xn_1_eq0_add //.
- - rewrite /Pred.Exports.oppr_closed /mem /= /ideal_xn_1_pred => a.
- rewrite /in_mem /= => /eqP-eqq1.
- rewrite poly_rem_xn_1_eq0_opp //.
-Qed.
-
-Definition ideal_xn_1_idealr [R : comRingType] n : idealr (ideal_xn_1_pred (R := R) n)
- := MkIdeal (ideal_xn_1_pred_zmod n) (ideal_xn_1_pred_proper n).
-
-Definition qring_xn [R : comRingType] n
- := { ideal_quot (KeyedPred (ideal_xn_1_idealr (R := R) n)) }.
-
-Definition foo1_add [R : comRingType] n (a b : qring_xn (R:=R) n) := a + b.
-
-Section polyops.
-
- Context {T:comRingType}.
-
- Definition monic_poly := {p:{poly T} | (p \is monic) && (seq.size p > 1)}.
-
- Import Pdiv.
- Import CommonRing.
- Definition princ_ideal_pred (p : {poly T}) : pred {poly T} :=
- fun q => rmodp q p == 0.
-(*
- fun q => q %% p == 0.
-*)
-
-Lemma princ_ideal_proper (p : monic_poly) :
- proper_ideal (princ_ideal_pred (val p)).
-Proof.
- intros.
- unfold proper_ideal, princ_ideal_pred, in_mem, mem; split; simpl.
- - rewrite rmodp_small.
- + rewrite poly1_neq0//.
- + destruct (andP (valP p)).
- rewrite size_poly1 H0 //.
- - intros ??.
- rewrite -RingMonic.rmodp_mulmr // /in_mem/=.
- move => /eqP->.
- rewrite mulr0 rmod0p//.
- now destruct (andP (valP p)).
-Qed.
-
-Lemma princ_ideal_zmod (p : monic_poly) :
- zmodPred (princ_ideal_pred (val p)).
-Proof.
- constructor.
- - constructor; [constructor|].
- constructor.
- + rewrite /in_mem //= /princ_ideal_pred rmod0p//.
- + rewrite /in_mem //= /prop_in2 /princ_ideal_pred => a b.
- rewrite /in_mem /mem /= RingMonic.rmodpD // /=.
- * move => /eqP-> /eqP->.
- rewrite addr0//.
- * now destruct (andP (valP p)).
- - rewrite /Pred.Exports.oppr_closed /mem /= /princ_ideal_pred => a.
- rewrite /in_mem /= => /eqP-eqq1.
- destruct (andP (valP p)).
- rewrite rmodp_monic_opp // eqq1 oppr0 //.
-Qed.
-
-Definition princ_ideal (p : monic_poly) :
- idealr (princ_ideal_pred (val p))
- := MkIdeal (princ_ideal_zmod p) (princ_ideal_proper p).
-
-Definition qring (p : monic_poly)
- := { ideal_quot (KeyedPred (princ_ideal p)) }.
-
-Section example.
- Context (p: monic_poly).
-
- Definition foo_add (a b : qring p) := a + b.
- Definition foo_mul (a b : qring p) := a * b.
-
- Local Open Scope quotient_scope.
-
- Definition lift (a : {poly T}) : qring p
- := lift_cst (qring p) a.
-
- Definition proj (a : {poly T}) := (\pi_(qring p) a).
- Definition proj2 (a : {poly T}) : qring p := (\pi a).
-
- Example something (a b : {poly T}) := a == b %[mod (qring p)].
-
-End example.
-(*
-Require Import qpoly.
-Section qpoly.
- Context {T:comRingType}.
-Definition cyclotomic n : {poly T} := ('X^n.+1 + 1%:P).
-Definition qpoly_add n (p q : {qpoly (cyclotomic n)}) := p + q.
-Definition qpoly_mul n (p q : {qpoly (cyclotomic n)}) := p * q.
-Definition qpoly_inj n (p : {poly T}) := in_qpoly (cyclotomic n) p.
-End qpoly.
-*)
-
-End polyops.
-
-Section rmorphism.
-Lemma RtoC_is_rmorphism :
- rmorphism RtoC.
-Proof.
- constructor.
- - intros ??.
- unfold RtoC, add; simpl.
- f_equal.
- rewrite addrN //.
- - split.
- + intros ??.
- unfold RtoC, mul; simpl.
- f_equal.
- * rewrite mulr0 oppr0 addr0 //.
- * rewrite mulr0 mul0r addr0 //.
- + unfold RtoC, one; simpl.
- unfold one, real_complex_def; simpl.
- f_equal.
-Qed.
-
-Canonical RtoC_rmorphism := RMorphism RtoC_is_rmorphism.
-
-Lemma map_RtoC_is_rmorphism :
- rmorphism (map_poly RtoC).
-Proof.
- apply map_poly_is_rmorphism.
-Qed.
-
-Canonical map_RtoC_rmorphism := RMorphism map_RtoC_is_rmorphism.
-
-Definition peval_C (p : {poly R}) (x : C) : C :=
- (map_poly RtoC p).[x].
-
-Lemma ev_C_is_rmorphism (x:C) :
- rmorphism (fun (p : {poly R}) => peval_C p x).
-Proof.
- unfold peval_C.
- constructor.
- - intros ??.
- rewrite -horner_evalE !rmorphB //.
- - split.
- + intros ??.
- rewrite -horner_evalE !rmorphM //.
- + rewrite -horner_evalE !rmorph1 //.
-Qed.
-
-Canonical ev_C_rmorphism (x:R[i]) := RMorphism (ev_C_is_rmorphism x).
-
-Lemma comp_poly_is_rmorphism (p : {poly R}) :
- rmorphism (fun (q : {poly R}) => comp_poly p q).
-Proof.
- constructor.
- - intros ??.
- by rewrite comp_polyB.
- - split.
- + intros ??.
- by rewrite comp_polyM.
- + by rewrite comp_polyC.
-Qed.
-
-Canonical comp_poly_rmorphism (p : {poly R}) := RMorphism (comp_poly_is_rmorphism p).
-
-Lemma sum_conj (n : nat) (F : 'I_n -> R[i]) :
- conjc (\sum_(i < n) (F i)) = \sum_(i
- RtoC (Re c) = c.
-Proof.
- intros.
- rewrite /RtoC.
- destruct c; simpl.
- apply /eqP.
- rewrite eq_complex.
- rewrite /=.
- apply /andP.
- split; trivial.
- apply /eqP.
- simpl in H.
- by rewrite H.
-Qed.
-
-Lemma RtoC_cnorm (c : R[i]) :
- RtoC (cnorm c) = c * conjc c.
-Proof.
- rewrite /cnorm.
- apply RtoC_Re_Im0.
- destruct c.
- simpl.
- lra.
-Qed.
-
-Lemma RtoC_ctrace (c : R[i]) :
- RtoC (ctrace c) = c + conjc c.
-Proof.
- rewrite /ctrace.
- apply RtoC_Re_Im0.
- destruct c.
- simpl.
- lra.
-Qed.
-
-Lemma RtoC_opp (r : R) :
- RtoC (- r) = - RtoC r.
-Proof.
- rewrite /RtoC /=.
- apply /eqP.
- rewrite eq_complex /=.
- apply /andP; split; apply /eqP.
- - by rewrite /opp /=.
- - lra.
-Qed.
-
-Definition characteristic_polynomial (c : R[i]) : {poly R} :=
- 'X^2 + (- ctrace c) *: 'X + (cnorm c)%:P.
-
-Lemma size_charpoly (c : R[i]) :
- size (characteristic_polynomial c) = 3%N.
-Proof.
- rewrite /characteristic_polynomial.
- rewrite -addrA size_addl size_polyXn //.
- case : (eqVneq (ctrace c) 0); intros.
- - rewrite e oppr0 scale0r add0r.
- rewrite size_polyC.
- case : (eqVneq (cnorm c) 0); rewrite //.
- - rewrite size_addl.
- + rewrite size_scale.
- * rewrite size_polyX //.
- * by rewrite -eqr_opp opprK oppr0.
- + rewrite size_scale.
- * rewrite size_polyX size_polyC.
- case : (eqVneq (cnorm c) 0); rewrite //.
- * by rewrite -eqr_opp opprK oppr0.
-Qed.
-
-Lemma monic_charpoly (c : R[i]) :
- characteristic_polynomial c \is monic.
-Proof.
- apply /monicP.
- rewrite /characteristic_polynomial.
- rewrite -addrA.
- rewrite lead_coefDl.
- - apply lead_coefXn.
- - rewrite size_polyXn.
- case : (eqVneq (ctrace c) 0); intros.
- + rewrite e oppr0 scale0r add0r.
- rewrite size_polyC.
- case : (eqVneq (cnorm c) 0); rewrite //.
- + rewrite size_addl.
- * rewrite size_scale.
- -- rewrite size_polyX //.
- -- by rewrite -eqr_opp opprK oppr0.
- * rewrite size_scale.
- -- rewrite size_polyX size_polyC.
- case : (eqVneq (cnorm c) 0); rewrite //.
- -- by rewrite -eqr_opp opprK oppr0.
-Qed.
-
-Lemma characteristic_polynomial_correct (c : R[i]) :
- (map_poly RtoC (characteristic_polynomial c)).[c] = 0.
-Proof.
- rewrite /map_poly size_charpoly /characteristic_polynomial.
- rewrite horner_poly.
- rewrite big_ord_recl big_ord_recl big_ord1 /= expr0 mulr1 /bump addn0 expr1 /=.
- replace (addn (S 0) (S 0)) with 2%N by lia.
- rewrite !coefD !coefZ !coefX !coefC !coefXn /=.
- rewrite mulr0 mulr1 !addr0 !add0r.
- rewrite rmorph1 RtoC_cnorm RtoC_opp RtoC_ctrace mul1r expr2 opprD mulrDl.
- rewrite -addrA (addrC _ (c * c)) (addrA _ (c * c) _) -mulrDl (addrC _ c).
- by rewrite addrN mul0r add0r mulrC -mulrDl addrN mul0r.
-Qed.
-
-Lemma trace_conj (c : R[i]) :
- ctrace c = ctrace (conjc c).
-Proof.
- by rewrite /ctrace conjcK addrC.
-Qed.
-
-Lemma norm_conj (c : R[i]) :
- cnorm c = cnorm (conjc c).
-Proof.
- by rewrite /cnorm conjcK mulrC.
-Qed.
-
-Lemma cnorm_peval_C_conj (p : {poly R}) (c : C) :
- cnorm (peval_C p c) = cnorm (peval_C p (conjc c)).
-Proof.
- by rewrite norm_conj peval_C_conj.
-Qed.
-
-Lemma char_poly_conj (c : R[i]) :
- characteristic_polynomial c = characteristic_polynomial (conjc c).
-Proof.
- by rewrite /characteristic_polynomial trace_conj norm_conj.
-Qed.
-
-Lemma char_poly_conj_alt (c : R[i]) :
- (map_poly RtoC (characteristic_polynomial c)).[conjc c] = 0.
-Proof.
- rewrite char_poly_conj.
- apply characteristic_polynomial_correct.
-Qed.
-
-Lemma ctrace_eq (c1 c2 : R[i]) :
- ctrace c1 = ctrace c2 <-> c1 + conjc c1 = c2 + conjc c2.
-Proof.
- rewrite /ctrace.
- split; intros.
- - apply /eqP.
- rewrite eq_complex.
- apply /andP.
- split.
- + now apply /eqP.
- + by rewrite !ctrace_correct.
- - by rewrite H.
-Qed.
-
-Lemma cnorm_eq (c1 c2 : R[i]) :
- cnorm c1 = cnorm c2 <-> c1 * conjc c1 = c2 * conjc c2.
-Proof.
- rewrite /cnorm.
- split; intros.
- - apply /eqP.
- rewrite eq_complex.
- apply /andP.
- split.
- + now apply /eqP.
- + by rewrite !cnorm_correct.
- - by rewrite H.
-Qed.
-
-Lemma norm_trace_eq (c1 c2 : R[i]) :
- ctrace c1 = ctrace c2 /\ cnorm c1 = cnorm c2 <->
- c1 = c2 \/ c1 = conjc c2.
-Proof.
- split; intros.
- - destruct H.
- destruct c1; destruct c2.
- rewrite /ctrace /= in H.
- rewrite /cnorm /= in H0.
- assert (Re = Re0) by lra.
- rewrite H1 in H0.
- assert (Im *Im = Im0 * Im0) by lra.
- rewrite /conjc.
- rewrite H1.
- assert (Im = Im0 \/ Im = -Im0).
- {
- rewrite -!expr2 in H2.
- move /eqP in H2.
- rewrite eqf_sqr in H2.
- move /orP in H2.
- destruct H2.
- - by move /eqP in H2; left.
- - by move /eqP in H2; right.
- }
- destruct H3.
- + by rewrite H3; left.
- + by rewrite H3; right.
- - destruct H; rewrite H; try easy.
- by rewrite trace_conj norm_conj conjcK.
-Qed.
-
-Lemma charpoly_eq (c1 c2 : R[i]) :
- characteristic_polynomial c1 = characteristic_polynomial c2 <->
- c1 = c2 \/ c1 = conjc c2.
-Proof.
- split; intros.
- - rewrite /characteristic_polynomial in H.
- apply polyP in H.
- generalize (H 0%N); intros.
- generalize (H 1%N); intros.
- rewrite !coefD !coefZ !coefC !coefX !coefXn /= in H0.
- rewrite !coefD !coefZ !coefC !coefX !coefXn /= in H1.
- rewrite !mulr0 addr0 !add0r in H0.
- rewrite !addr0 !add0r !mulr1 in H1.
- apply norm_trace_eq.
- split; trivial.
- lra.
- - destruct H.
- + by rewrite H.
- + by rewrite H char_poly_conj conjcK.
-Qed.
-
-Lemma poly2_expand {S:comRingType} (c1 c2 : S) :
- 'X^2 - (c1 + c2)*: 'X + (c1*c2)%:P =
- ('X - c1%:P) * ('X - c2%:P).
-Proof.
- rewrite mulrDr !mulrDl addrA.
- f_equal.
- - rewrite scalerDl -expr2 -addrA.
- f_equal.
- rewrite (mulrC 'X _) opprD -!scaleNr -!mul_polyC.
- by f_equal; rewrite polyCN.
- - by rewrite mulrNN -scale_polyC mul_polyC.
-Qed.
-
-Lemma RtoCR n : RtoC n%:R = n%:R.
-Proof.
- unfold RtoC.
- induction n.
- - now rewrite mulr0n.
- - rewrite mulrSr /= mulrS -IHn.
- apply /eqP.
- rewrite eq_complex /= addrC addr0 /=.
- by apply /andP.
-Qed.
-
-Lemma charpoly_factor (c : R[i]) :
- map_poly RtoC (characteristic_polynomial c) =
- ('X - c%:P) * ('X - (conjc c)%:P).
-Proof.
- move: (size_charpoly c).
- rewrite -poly2_expand /characteristic_polynomial /ctrace /cnorm.
- rewrite map_polyE=>sz.
- apply/polyP=>i.
- rewrite coef_Poly.
- case/orP: (leqVgt 3 i)=>ibd.
- - rewrite !(coefD, coefZ, coefC, coefN, coefX, coefXn).
- replace (i == 2%N) with false by lia.
- replace (i == 1%N) with false by lia.
- replace (i == 0%N) with false by lia.
- rewrite /= mulr0 mulr0n oppr0 !addr0.
- rewrite nth_default //.
- by rewrite size_map sz.
- - rewrite (nth_map 0); [| by rewrite sz].
- rewrite !(coefD, coefZ, coefC, coefN, coefX, coefXn).
- rewrite !(rmorphD, rmorphM, rmorphN) /=.
- rewrite !RtoCR RtoC_ctrace mulNr.
- f_equal.
- case : (eqVneq i 0%N).
- + by rewrite RtoC_cnorm.
- + by rewrite /RtoC.
- Qed.
-
-Lemma charpoly_irreducible (c : R[i]) :
- c != conjc c ->
- irreducible_poly (characteristic_polynomial c).
-Proof.
- intros.
- assert (1 < size (characteristic_polynomial c) <= 4).
- {
- rewrite size_charpoly //.
- }
- apply (cubic_irreducible H0).
- intros.
- apply /negP.
- unfold not; intros.
- assert (root (map_poly RtoC (characteristic_polynomial c)) (RtoC x)).
- {
- by apply rmorph_root.
- }
- rewrite charpoly_factor in H2.
- unfold root in H2.
- rewrite hornerM hornerD mulf_eq0 in H2.
- move /orP in H2.
- destruct H2.
- - move /eqP in H2.
- rewrite hornerX hornerN hornerC in H2.
- apply (f_equal (fun z => z+ c)) in H2.
- rewrite -addrA add0r (addrC _ c) addrN addr0 in H2.
- by rewrite -H2 /RtoC /= eq_complex /= oppr0 !eq_refl /= in H.
- - move /eqP in H2.
- rewrite hornerD hornerX hornerN hornerC in H2.
- apply (f_equal (fun z => z+ conjc c)) in H2.
- rewrite -addrA add0r (addrC _ (conjc c)) addrN addr0 in H2.
- apply (f_equal (fun z => conjc z)) in H2.
- rewrite conjcK /RtoC /= in H2.
- by rewrite -H2 /RtoC /= eq_complex /= opprK oppr0 !eq_refl /= in H.
- Qed.
-
-Lemma charpoly_square (c : R[i]) :
- c == conjc c ->
- characteristic_polynomial c = ('X - (Re c)%:P)^+2.
-Proof.
- intros.
- move /eqP in H.
- rewrite expr2 /characteristic_polynomial /ctrace /cnorm H conjcK /= -H -poly2_expand.
- f_equal.
- - f_equal.
- rewrite -scaleNr.
- f_equal.
- f_equal.
- destruct c.
- by simpl.
- - f_equal.
- destruct c.
- simpl.
- move /eqP in H.
- rewrite eq_complex /= in H.
- move /andP in H.
- destruct H.
- move /eqP in H0.
- assert (Im = 0) by lra.
- rewrite H1.
- lra.
-Qed.
-
-Lemma charpoly_coprime_case1 (c1 c2 : R[i]) :
- c1 != conjc c1 ->
- c2 != conjc c2 ->
- characteristic_polynomial c1 != characteristic_polynomial c2 ->
- coprimep (characteristic_polynomial c1) (characteristic_polynomial c2).
-Proof.
- intros.
- apply /coprimepP.
- intros.
- apply charpoly_irreducible in H.
- apply charpoly_irreducible in H0.
- specialize (H d); intros.
- specialize (H0 d); intros.
- rewrite -size_poly_eq1.
- case : (eqVneq (size d) 1%N); trivial.
- intros.
- specialize (H4 i H2).
- specialize (H5 i H3).
- rewrite eqp_sym in H4.
- generalize (eqp_trans H4 H5); intros.
- rewrite eqp_monic in H6.
- - by rewrite H6 in H1.
- - apply monic_charpoly.
- - apply monic_charpoly.
-Qed.
-
-Lemma charpoly_coprime_case2 (c1 c2 : R[i]) :
- c1 != conjc c1 ->
- c2 == conjc c2 ->
- characteristic_polynomial c1 != characteristic_polynomial c2 ->
- coprimep (characteristic_polynomial c1) (characteristic_polynomial c2).
-Proof.
- intros.
- apply charpoly_square in H0.
- apply /coprimepP.
- intros.
- apply charpoly_irreducible in H.
- specialize (H d); intros.
- rewrite -size_poly_eq1.
- case : (eqVneq (size d) 1%N); trivial.
- intros.
- specialize (H4 i H2).
- assert (size d == size (characteristic_polynomial c2)).
- {
- generalize (eqp_size H4); intros.
- rewrite size_charpoly.
- rewrite size_charpoly in H5.
- by rewrite H5.
- }
- rewrite Pdiv.CommonIdomain.dvdp_size_eqp in H5; trivial.
- rewrite eqp_sym in H4.
- generalize (eqp_trans H4 H5); intros.
- rewrite eqp_monic in H6.
- - by rewrite H6 in H1.
- - apply monic_charpoly.
- - apply monic_charpoly.
-Qed.
-
-Lemma charpoly_coprime_case3 (c1 c2 : R[i]) :
- c1 == conjc c1 ->
- c2 == conjc c2 ->
- characteristic_polynomial c1 != characteristic_polynomial c2 ->
- coprimep (characteristic_polynomial c1) (characteristic_polynomial c2).
-Proof.
- intros.
- apply charpoly_square in H.
- apply charpoly_square in H0.
- rewrite H H0.
- rewrite H H0 in H1.
- apply coprimep_expl.
- apply coprimep_expr.
- apply coprimep_XsubC2.
- apply /negP.
- unfold not; intros.
- rewrite subr_eq0 in H2.
- move /eqP in H2.
- rewrite H2 in H1.
- move /negP in H1.
- by apply H1.
-Qed.
-
-Lemma charpoly_comprime (c1 c2 : R[i]) :
- characteristic_polynomial c1 != characteristic_polynomial c2 ->
- coprimep (characteristic_polynomial c1) (characteristic_polynomial c2).
-Proof.
- case : (boolP (c1 == conjc c1)); intros.
- - case : (boolP (c2 == conjc c2)); intros.
- + by apply charpoly_coprime_case3.
- + rewrite coprimep_sym.
- rewrite eq_sym in H.
- by apply charpoly_coprime_case2.
- - case : (boolP (c2 == conjc c2)); intros.
- + by apply charpoly_coprime_case2.
- + by apply charpoly_coprime_case1.
-Qed.
-
- Lemma ev_C_1 :
- forall (x : C), peval_C 1 x = 1.
- Proof.
- apply ev_C_is_rmorphism.
- Qed.
-
- Definition peval_C_ker_pred (x : C) : pred {poly R} :=
- fun p => peval_C p x == 0.
-
- Lemma peval_C_ker_proper (x : C) :
- proper_ideal (peval_C_ker_pred x).
- Proof.
- split.
- - by rewrite /peval_C_ker_pred /in_mem /mem /= ev_C_1 oner_neq0.
- - move => a b.
- rewrite /in_mem /=.
- rewrite /peval_C_ker_pred.
- case: (ev_C_is_rmorphism x) => _ -> /eqP->.
- by rewrite mulr0.
- Qed.
-
- Lemma peval_C_ker_zmod (x : C) :
- zmodPred (peval_C_ker_pred x).
- Proof.
- constructor.
- - constructor; [constructor|].
- constructor.
- + rewrite /in_mem //= /peval_C_ker_pred /peval_C.
- unfold map_poly.
- rewrite poly_size_0.
- rewrite (eq_poly (fun _ => 0)).
- * rewrite -{2}(horner0 x).
- apply /eqP.
- f_equal.
- apply /polyP => i /=.
- rewrite coef_poly coefC /=.
- f_equal.
- by case: (i == 0)%nat.
- * move=> i ilt.
- rewrite coefC.
- by case: (i == 0)%nat.
- + rewrite /in_mem //= /prop_in2 /peval_C_ker_pred => a b.
- rewrite /in_mem /mem /= .
- generalize (raddfD (ev_C_rmorphism x)); intros.
- simpl in H; rewrite H.
- revert H0 H1.
- move => /eqP -> /eqP->.
- rewrite add0r //.
- - rewrite /Pred.Exports.oppr_closed /mem /= /peval_C_ker_pred => a.
- rewrite /in_mem /= => /eqP-eqq1.
- generalize (raddfN (ev_C_rmorphism x) a); intros.
- simpl in H.
- rewrite H eqq1 oppr0 //.
- Qed.
-
- Definition peval_C_ker_is_ideal (x:C) :
- idealr (peval_C_ker_pred x)
- := MkIdeal (peval_C_ker_zmod x) (peval_C_ker_proper x).
-
- Canonical peval_C_ker_ideal (x:C) := KeyedPred (peval_C_ker_is_ideal x).
-
- Definition peval_C_ker_quot_ring (x:C)
- := { ideal_quot (peval_C_ker_ideal x) }.
-
- Local Open Scope quotient_scope.
-
- Definition peval_C_quot (x:C) : peval_C_ker_quot_ring x -> C
- := lift_fun1 (peval_C_ker_quot_ring x) (fun p => peval_C p x).
-
- Lemma pi_peval_C_quot x : {mono (\pi_(peval_C_ker_quot_ring x)) : p / peval_C p x >-> peval_C_quot x p}.
- Proof.
- move=> p.
- rewrite /peval_C_quot -eq_lock.
- case: piP => a /EquivQuot.eqmodP.
- rewrite /Quotient.equiv_equiv /Quotient.equiv /in_mem /mem /= /peval_C_ker_pred.
- destruct (ev_C_is_rmorphism x).
- rewrite base => eqq.
- move=> /eqP in eqq.
- apply (f_equal (fun z => z + peval_C a x)) in eqq.
- by rewrite -addrA add0r (addrC _ (peval_C a x)) addrN addr0 in eqq.
- Qed.
-
- Lemma peval_C_quot1 x : peval_C_quot x 1 = 1.
- Proof.
- rewrite /one /= /Quotient.one /= /one /= /locked.
- destruct master_key.
- rewrite pi_peval_C_quot /peval_C.
- rewrite /map_poly size_poly1.
- rewrite horner_poly big_ord1 expr0 mulr1.
- by rewrite -horner_coef0 hornerC.
- Qed.
-
- Lemma peval_quot_C_is_rmorphism (c:C): rmorphism (peval_C_quot c).
- Proof.
- split => [x|].
- - apply quotP=> y <-.
- revert x.
- apply quotP => x <-.
- rewrite !reprK.
- rewrite !pi_peval_C_quot.
- rewrite /peval_C_quot -!eq_lock.
- rewrite -pi_is_additive.
- case: piP => y' /eqquotP.
- rewrite /peval_C_ker_pred /in_mem/mem/= => /eqP.
- generalize (raddfD (ev_C_rmorphism c)); intros peval_C_add.
- generalize (raddfN (ev_C_rmorphism c)); intros peval_C_opp.
- generalize (peval_C_add (x-y) (-y')); intros add1.
- simpl in add1; rewrite add1.
- specialize (peval_C_add x (-y)); simpl in peval_C_add.
- rewrite peval_C_add.
- generalize (peval_C_opp y); intro opp1.
- simpl in opp1; rewrite opp1.
- specialize (peval_C_opp y'); simpl in peval_C_opp.
- rewrite peval_C_opp.
- intro HH.
- apply (f_equal (fun z => z + peval_C y' c)) in HH.
- rewrite add0r -!addrA in HH.
- rewrite (addrC _ (peval_C y' c)) addrN addr0 in HH.
- by rewrite -HH.
- - constructor.
- + move => x.
- apply quotP=> y <-.
- revert x.
- apply quotP => x <-.
- rewrite !reprK.
- rewrite !pi_peval_C_quot.
- rewrite /peval_C_quot -!eq_lock.
- rewrite -pi_is_multiplicative.
- case: piP => y' /eqquotP.
- rewrite /peval_C_ker_pred /in_mem/mem/= => /eqP.
- destruct (ev_C_is_rmorphism c) as [? [??]].
- specialize (base (x * y) y'); simpl in base.
- rewrite base.
- specialize (m x y); simpl in m.
- rewrite m.
- intro HH.
- apply (f_equal (fun z => z + peval_C y' c)) in HH.
- rewrite add0r -!addrA in HH.
- rewrite (addrC _ (peval_C y' c)) addrN addr0 in HH.
- by rewrite -HH.
- + by apply peval_C_quot1.
- Qed.
-
- Lemma peval_C_quot_is_injective (c:C) :
- injective (peval_C_quot c).
- Proof.
- intros x y.
- rewrite /peval_C_quot -!eq_lock.
- rewrite -{2}[x]reprK -{2}[y]reprK.
- move: (repr x) (repr y) => {x} {y} x y eqq.
- apply/eqquotP.
- rewrite /Quotient.equiv/=.
- rewrite /peval_C_ker_pred /in_mem/mem/=.
- apply/eqP.
- destruct (ev_C_is_rmorphism c).
- specialize (base x y).
- simpl in base.
- by rewrite eqq addrN in base.
- Qed.
-
- Lemma cval_decomp (c1 c2 : C) :
- Im c1 != 0 ->
- {a : R * R | (a.1%:C * c1 + a.2%:C)%C = c2}.
- Proof.
- intros.
- exists ((Im c2/Im c1), (Re c2 - (Im c2 / Im c1)*Re c1)).
- apply /eqP.
- rewrite eq_complex.
- apply /andP.
- destruct c1.
- destruct c2.
- simpl.
- rewrite !mul0r !addr0 subr0.
- - split.
- + apply /eqP.
- rewrite addrC.
- generalize (subrKA (Im0 / Im * Re) Re0 0); intros.
- by rewrite !addr0 in H0.
- + rewrite -mulrA.
- rewrite (mulrC _ Im).
- rewrite divff //.
- by rewrite mulr1.
- Qed.
-
- Lemma peval_C_decomp (c1 c2 : C) :
- Im c1 != 0 ->
- {p : {poly R} | peval_C p c1 = c2}.
- Proof.
- intros.
- destruct (cval_decomp c1 c2 H).
- exists (x.1 *: 'X + x.2 %:P).
- rewrite /peval_C -e.
- case: (eqVneq x.1 0) => [-> |].
- - rewrite scale0r add0r /map_poly.
- rewrite horner_poly mul0r add0r.
- case: (eqVneq x.2 0) => [-> |].
- + by rewrite size_poly0 big_ord0.
- + intros.
- rewrite size_polyC.
- rewrite i // big_ord1 expr0 mulr1 coefC //.
- - intros.
- rewrite /map_poly size_addl size_scale //; rewrite size_polyX.
- + rewrite horner_poly.
- rewrite big_ord_recl big_ord1 /= expr0 mulr1 /bump /= addn0 expr1.
- rewrite !coefD !coefZ !coefX !coefC /=.
- rewrite mulr0 mulr1 addr0 add0r.
- by rewrite addrC mulrC.
- + rewrite size_polyC.
- by case: (eqVneq x.2 0).
- Qed.
-
- Lemma peval_C_quot_is_surjective (c c2 :C) :
- Im c != 0 ->
- {p: peval_C_ker_quot_ring c | peval_C_quot c p = c2}.
- Proof.
- intros.
- destruct (peval_C_decomp c c2); trivial.
- exists (\pi_(peval_C_ker_quot_ring c) x).
- rewrite -e.
- apply pi_peval_C_quot.
- Qed.
-
-
- Lemma peval_C_quot_is_bijective (c :C) :
- Im c != 0 ->
- bijective (peval_C_quot c).
- Proof.
- intros imn0.
- pose g : R[i] -> peval_C_ker_quot_ring c :=
- fun c2 => sval (peval_C_quot_is_surjective c c2 imn0).
- apply Bijective with (g := g).
- - intros ?.
- assert (peval_C_quot c (g (peval_C_quot c x)) = peval_C_quot c x).
- {
- rewrite /g.
- destruct (peval_C_quot_is_surjective c (peval_C_quot c x) imn0).
- by simpl.
- }
- by apply peval_C_quot_is_injective in H.
- - intros ?.
- rewrite /g.
- destruct (peval_C_quot_is_surjective c x imn0).
- by simpl.
- Qed.
-
-
-End rmorphism.
-
-Module matrix_ring.
-Section matrixRing.
- Section base_ring.
- Context {T:ringType}.
-
- Variable (n m : nat).
-
- Definition MR_mul (A B : 'M[T]_(S n, S m)) := map2_mx (fun (a b : T) => a * b) A B.
- Definition MR1 : 'M[T]_(S n,S m) := const_mx 1.
- Definition MR0: 'M[T]_(S n,S m) := const_mx 0.
- Lemma MR_mulA : associative MR_mul.
- Proof.
- by move=> A B C; apply/matrixP=> i j; rewrite !mxE Monoid.mulmA.
- Qed.
-
- Lemma MR_mul1z : left_id MR1 MR_mul.
- Proof.
- intros ?.
- unfold MR_mul, MR1.
- unfold map2_mx, const_mx.
- apply matrixP; intros ??.
- rewrite !mxE mul1r //.
- Qed.
-
- Lemma MR_mulz1 : right_id MR1 MR_mul.
- Proof.
- intros ?.
- unfold MR_mul, MR1.
- unfold map2_mx, const_mx.
- apply matrixP; intros ??.
- rewrite !mxE mulr1 //.
- Qed.
-
- Fact mul_0MR : left_zero MR0 MR_mul.
- Proof.
- intros ?.
- unfold MR_mul, MR1.
- unfold map2_mx, const_mx.
- apply matrixP; intros ??.
- rewrite !mxE mul0r //.
- Qed.
-
- Fact mul_MR0 : right_zero MR0 MR_mul.
- Proof.
- intros ?.
- unfold MR_mul, MR1.
- unfold map2_mx, const_mx.
- apply matrixP; intros ??.
- rewrite !mxE mulr0 //.
- Qed.
-
- Lemma MR_mul_addr : right_distributive MR_mul (@addmx T (S n) (S m)).
- Proof.
- move=> A B C; apply/matrixP=> i j.
- unfold MR_mul, addmx, map2_mx.
- rewrite !mxE mulrDr //.
- Qed.
-
- Lemma MR_mul_addl : left_distributive MR_mul (@addmx T (S n) (S m)).
- Proof.
- move=> A B C; apply/matrixP=> i j.
- unfold MR_mul, addmx, map2_mx.
- rewrite !mxE mulrDl //.
- Qed.
-
- Fact MR1_neq0 : MR1 != MR0.
- Proof.
- unfold MR1, MR0, const_mx.
- apply /eqP/matrixP.
- move/(_ ord0 ord0).
- rewrite !mxE.
- apply/eqP/oner_neq0.
- Qed.
-
- Definition MR_ringMixin :=
- RingMixin MR_mulA MR_mul1z MR_mulz1 MR_mul_addl MR_mul_addr MR1_neq0.
-
- Canonical MR_ringType := Eval hnf in RingType 'M[T]_(S n,S m) MR_ringMixin.
- End base_ring.
-
- Section com_ring.
-
- Context {T:comRingType}.
- Variable (n m : nat).
-
- Lemma MR_mulC : commutative (@MR_mul T n m).
- Proof.
- intros ??.
- unfold MR_mul, map2_mx.
- apply eq_mx; intros ??.
- rewrite mulrC //.
- Qed.
-
- Canonical MR_comRingType := Eval hnf in ComRingType (@MR_ringType T n m) MR_mulC.
-
-(* Definition MR_embed (val : 'M[T]_(S n,S m)) : MR_comRingType := val. *)
-
- End com_ring.
-
- Section unit_ring.
- Context {T:unitRingType}.
- Variable (n m : nat).
-
- Definition MR_unit : pred 'M[T]_(S n,S m)
- := [pred m : 'M[T]_(S n,S m) | [forall i, [forall j, (m i j) \is a unit]]].
-
- Definition MR_inv (x:'M[T]_(S n,S m)) : 'M[T]_(S n,S m)
- := if MR_unit x then map_mx inv x else x.
-
- Lemma MR_mulVr : ({in MR_unit, left_inverse 1 MR_inv (@mul (MR_ringType _ _))}).
- Proof.
- move=> M.
- rewrite /MR_inv.
- rewrite -topredE /= => eqq1.
- rewrite eqq1.
- apply/matrixP => x y.
- rewrite !mxE.
- apply mulVr.
- move: eqq1.
- rewrite/MR_unit simpl_predE.
- by move/forallP /(_ x)/forallP/(_ y).
- Qed.
-
- Lemma MR_mulrV : ({in MR_unit, right_inverse 1 MR_inv (@mul (MR_ringType _ _))}).
- Proof.
- move=> M.
- rewrite /MR_inv.
- rewrite -topredE /= => eqq1.
- rewrite eqq1.
- apply/matrixP => x y.
- rewrite !mxE.
- apply mulrV.
- move: eqq1.
- rewrite/MR_unit simpl_predE.
- by move/forallP /(_ x)/forallP/(_ y).
- Qed.
-
- Lemma MR_unitP (x y : MR_ringType n m) : y * x = 1 /\ x * y = 1 -> MR_unit x.
- Proof.
- move=>[/matrixP-linv /matrixP-rinv].
- rewrite/MR_unit/simpl_predE.
- apply/forallP => i.
- apply/forallP => j.
- apply/unitrP.
- exists (y i j).
- move: linv => /(_ i j).
- move: rinv => /(_ i j).
- by rewrite !mxE => -> ->.
- Qed.
-
- Lemma MR_inv0id : {in [predC MR_unit], MR_inv =1 id}.
- Proof.
- move => x.
- by rewrite -topredE /= -topredE /= /MR_inv => /Bool.negb_true_iff->.
- Qed.
-
- Definition MR_unitRingMixin : UnitRing.mixin_of (Ring.Pack (Ring.class (MR_ringType n m)))
- := @UnitRingMixin (MR_ringType _ _) MR_unit MR_inv MR_mulVr MR_mulrV MR_unitP MR_inv0id.
-
- Canonical MR_unitRingType := Eval hnf in UnitRingType (@MR_ringType T n m) MR_unitRingMixin.
-
- End unit_ring.
-
- End matrixRing.
-End matrix_ring.
-
-Section eval_vectors.
-
- Import matrix_ring.
-
- Context {n} (vals : 'rV[R[i]]_(n.+1)).
-
- Definition mx_eval (p : {poly R}) : MR_comRingType 0 n :=
- (map_mx (fun x => (map_poly RtoC p).[x]) vals).
-
- Lemma mx_evalC c : mx_eval c%:P = const_mx (RtoC c).
- Proof.
- apply matrixP=> a b.
- rewrite !mxE.
- by rewrite (map_polyC RtoC_rmorphism) /= hornerC.
- Qed.
-
- Lemma mx_eval_is_rmorphism :
- rmorphism (fun p => mx_eval p).
- Proof.
- constructor.
- - move=> x y.
- apply matrixP=> a b.
- rewrite !mxE.
- by apply ev_C_is_rmorphism.
- - split.
- + move=> x y.
- apply matrixP=> a b.
- rewrite !mxE.
- by apply ev_C_is_rmorphism.
- + by rewrite mx_evalC.
- Qed.
-
- Canonical mx_eval_rmorphism
- : {rmorphism poly_ringType R_ringType -> MR_comRingType 0 n}
- := RMorphism (mx_eval_is_rmorphism).
-
- Lemma mx_eval_1 :
- mx_eval 1 = 1.
- Proof.
- apply mx_eval_is_rmorphism.
- Qed.
-
- Definition mx_eval_ker_pred : pred {poly R} :=
- fun p => mx_eval p == 0.
-
- Lemma mx_eval_ker_proper :
- proper_ideal (mx_eval_ker_pred).
- Proof.
- split.
- - by rewrite /mx_eval_ker_pred /in_mem /mem /= mx_eval_1 oner_neq0.
- - move => a b.
- rewrite /in_mem /=.
- rewrite /mx_eval_ker_pred.
- case: (mx_eval_is_rmorphism) => _ -> /eqP->.
- by rewrite mulr0.
- Qed.
-
- Lemma mx_eval_ker_zmod :
- zmodPred (mx_eval_ker_pred).
- Proof.
- constructor.
- - constructor; [constructor|].
- constructor.
- + rewrite /in_mem //= /mx_eval_ker_pred /mx_eval /map_mx.
- apply/eqP/matrixP=>a b.
- rewrite !mxE.
- unfold map_poly.
- generalize (hornerC (RtoC 0) (vals a b)); intros.
- rewrite poly_size_0.
- rewrite (eq_poly (fun _ => 0)).
- * rewrite -{2}(horner0 (vals a b)).
- f_equal.
- apply /polyP => i /=.
- rewrite coef_poly coefC /=.
- by case: (i == 0)%nat.
- * move=> i ilt.
- rewrite coefC.
- by case: (i == 0)%nat.
- + rewrite /in_mem //= /prop_in2 /mx_eval_ker_pred => a b.
- rewrite /in_mem /mem /= .
- generalize (raddfD (mx_eval_rmorphism) a b); intros.
- simpl in H; rewrite H.
- revert H0 H1.
- move => /eqP -> /eqP->.
- rewrite add0r //.
- - rewrite /Pred.Exports.oppr_closed /mem /= /mx_eval_ker_pred => a.
- rewrite /in_mem /= => /eqP-eqq1.
- generalize (raddfN (mx_eval_rmorphism) a); intros.
- simpl in H.
- rewrite H eqq1 oppr0 //.
- Qed.
-
- Definition mx_eval_ker_is_ideal :
- idealr mx_eval_ker_pred
- := MkIdeal mx_eval_ker_zmod mx_eval_ker_proper.
-
- Canonical mx_eval_ker_ideal := KeyedPred (mx_eval_ker_is_ideal).
-
- Definition mx_eval_ker_quot_ring
- := { ideal_quot mx_eval_ker_ideal }.
-
- Local Open Scope quotient_scope.
-
- Definition mx_eval_quot : mx_eval_ker_quot_ring -> MR_comRingType 0 n
- := lift_fun1 mx_eval_ker_quot_ring mx_eval.
-
- Lemma pi_mx_eval_quot : {mono (\pi_mx_eval_ker_quot_ring) : x / mx_eval x >-> mx_eval_quot x}.
- Proof.
- move=> x.
- rewrite /mx_eval_quot -eq_lock.
- case: piP => a /EquivQuot.eqmodP.
- rewrite /Quotient.equiv_equiv /Quotient.equiv /in_mem /mem /= /mx_eval_ker_pred.
- destruct mx_eval_is_rmorphism.
- rewrite base => eqq.
- move=> /eqP in eqq.
- apply (f_equal (fun z => z + mx_eval a)) in eqq.
- by rewrite -addrA add0r (addrC _ (mx_eval a)) addrN addr0 in eqq.
- Qed.
-
- Lemma mx_eval_quotC c :
- mx_eval_quot (\pi_({ideal_quot mx_eval_ker_ideal}) c%:P) = const_mx (RtoC c).
- Proof.
- by rewrite pi_mx_eval_quot mx_evalC.
- Qed.
-
- Lemma mx_eval_quot1 : mx_eval_quot 1 = 1.
- Proof.
- rewrite /one /= /Quotient.one /= /one /= /locked.
- destruct master_key.
- by rewrite mx_eval_quotC.
- Qed.
-
- Lemma mx_eval_quot_is_rmorphism : rmorphism mx_eval_quot.
- Proof.
- split => [x|].
- - apply quotP=> y <-.
- revert x.
- apply quotP => x <-.
- rewrite !reprK.
- rewrite !pi_mx_eval_quot.
- rewrite /mx_eval_quot -!eq_lock.
- rewrite -pi_is_additive.
- case: piP => y' /eqquotP.
- rewrite /mx_eval_ker_pred /in_mem/mem/= => /eqP.
- generalize (raddfD (mx_eval_rmorphism)); intros mx_eval_add.
- generalize (raddfN (mx_eval_rmorphism)); intros mx_eval_opp.
- generalize (mx_eval_add (x-y) (-y')); intros add1.
- simpl in add1; rewrite add1.
- specialize (mx_eval_add x (-y)); simpl in mx_eval_add.
- rewrite mx_eval_add.
- generalize (mx_eval_opp y); intro opp1.
- simpl in opp1; rewrite opp1.
- specialize (mx_eval_opp y'); simpl in mx_eval_opp.
- rewrite mx_eval_opp.
- intro HH.
- apply (f_equal (fun z => z + mx_eval y')) in HH.
- rewrite add0r -!addrA in HH.
- rewrite (addrC _ (mx_eval y')) addrN addr0 in HH.
- by rewrite -HH.
- - constructor.
- + move => x.
- apply quotP=> y <-.
- revert x.
- apply quotP => x <-.
- rewrite !reprK.
- rewrite !pi_mx_eval_quot.
- rewrite /mx_eval_quot -!eq_lock.
- rewrite -pi_is_multiplicative.
- case: piP => y' /eqquotP.
- rewrite /mx_eval_ker_pred /in_mem/mem/= => /eqP.
- destruct mx_eval_is_rmorphism as [? [??]].
- specialize (base (x * y) y'); simpl in base.
- rewrite base.
- specialize (m x y); simpl in m.
- rewrite m.
- intro HH.
- apply (f_equal (fun z => z + mx_eval y')) in HH.
- rewrite add0r -!addrA in HH.
- rewrite (addrC _ (mx_eval y')) addrN addr0 in HH.
- by rewrite -HH.
- + by apply mx_eval_quot1.
- Qed.
-
- Lemma mx_eval_quot_is_injective :
- injective mx_eval_quot.
- Proof.
- intros x y.
- rewrite /mx_eval_quot -!eq_lock.
- rewrite -{2}[x]reprK -{2}[y]reprK.
- move: (repr x) (repr y) => {x} {y} x y eqq.
- apply/eqquotP.
- rewrite /Quotient.equiv/=.
- rewrite /mx_eval_ker_pred /in_mem/mem/=.
- apply/eqP.
- destruct mx_eval_is_rmorphism.
- specialize (base x y).
- simpl in base.
- by rewrite eqq addrN in base.
- Qed.
-
- Lemma poly_eval_mod (a b p : {poly R}) (c : R[i]) :
- a %% p = b %% p ->
- (map_poly RtoC p).[c] = 0 ->
- (map_poly RtoC a).[c] = (map_poly RtoC b).[c].
- Proof.
- intros.
- rewrite (divp_eq a p).
- rewrite (divp_eq b p) H.
- by rewrite !rmorphD !rmorphM !hornerD !hornerM H0 !mulr0 !add0r.
- Qed.
-
- Lemma mx_eval_is_surjective :
- let charvals := [seq characteristic_polynomial (vals 0 j) | j : 'I_n.+1] in
- pairwise (coprimep (R := R_fieldType)) charvals ->
- (forall j, Im (vals 0 j) != 0) ->
- forall (c : MR_comRingType 0 n),
- {x : {poly R} | mx_eval x = c}.
- Proof.
- intros charvals cop imn0 c.
-
- pose rvals := [seq sval (peval_C_decomp (vals 0 j) (c 0 j) (imn0 j)) | j : 'I_n.+1].
- have rvals_prop: forall (j:'I_n.+1), peval_C (rvals`_j) (vals 0 j) = (c 0 j).
- {
- subst rvals=>j.
- rewrite (nth_map 0)/=.
- - move: (svalP (peval_C_decomp (vals 0 (enum 'I_n.+1)`_j) (c 0 (enum 'I_n.+1)`_j) (imn0 (enum 'I_n.+1)`_j))).
- by rewrite !nth_ord_enum.
- - rewrite size_enum_ord.
- by destruct j.
- }
-
- have eqsize: size rvals = size charvals.
- {
- subst rvals charvals.
- by rewrite !size_map -enumT size_enum_ord.
- }
-
- generalize (chinesep_list_prop R_fieldType (zip rvals charvals)); intros.
- assert ([seq i.2 | i <- zip rvals charvals] = charvals).
- {
- apply (@eq_from_nth _ 0).
- - by rewrite size_map size_zip eqsize ?minnn.
- - intros.
- rewrite size_map in H.
- rewrite (nth_map 0) //.
- by rewrite nth_zip_cond H.
- }
- rewrite -H in cop.
- specialize (X cop).
- destruct X.
- exists x.
- apply /matrixP.
- intros ??.
- pose p := (zip rvals charvals)`_y.
- assert (p \in zip rvals charvals).
- {
- subst p.
- apply/(nthP 0).
- subst rvals charvals.
- exists y => //.
- by rewrite size_zip eqsize minnn size_map size_enum_ord ltn_ord.
- }
- specialize (e p H0).
- rewrite ord1 -(rvals_prop y).
- rewrite /mx_eval /peval_C.
- rewrite !mxE.
- have eqq1: charvals`_y = p.2.
- {
- subst p.
- by rewrite nth_zip.
- }
- have eqq2: rvals`_y = p.1.
- {
- subst p.
- by rewrite nth_zip.
- }
- rewrite eqq2.
- apply (poly_eval_mod _ _ _ _ e).
- rewrite -eqq1.
- assert (charvals`_y = characteristic_polynomial (vals 0 y)).
- {
- unfold charvals.
- rewrite (nth_map 0)/=.
- - by rewrite nth_ord_enum.
- - by rewrite size_enum_ord ltn_ord.
- }
- rewrite H1.
- apply characteristic_polynomial_correct.
- Qed.
-
- Lemma mx_eval_quot_is_surjective :
- let charvals := [seq characteristic_polynomial (vals 0 j) | j : 'I_n.+1] in
- pairwise (coprimep (R := R_fieldType)) charvals ->
- (forall j, Im (vals 0 j) != 0) ->
- forall (c : MR_comRingType 0 n),
- {x : mx_eval_ker_quot_ring | mx_eval_quot x = c}.
- Proof.
- intros.
- destruct (mx_eval_is_surjective H H0 c).
- exists (\pi_mx_eval_ker_quot_ring x).
- by rewrite pi_mx_eval_quot.
- Qed.
-
- Lemma mx_eval_quot_is_bijective :
- let charvals := [seq characteristic_polynomial (vals 0 j) | j : 'I_n.+1] in
- pairwise (coprimep (R := R_fieldType)) charvals ->
- (forall j, Im (vals 0 j) != 0) ->
- bijective mx_eval_quot.
- Proof.
- intros charvals cop imn0.
- pose g : MR_comRingType 0 n -> mx_eval_ker_quot_ring :=
- fun c => sval (mx_eval_quot_is_surjective cop imn0 c).
- apply Bijective with (g := g).
- - intros ?.
- assert (mx_eval_quot (g (mx_eval_quot x)) = mx_eval_quot x).
- {
- rewrite /g.
- destruct (mx_eval_quot_is_surjective cop imn0 (mx_eval_quot x)).
- by simpl.
- }
- by apply mx_eval_quot_is_injective in H.
- - intros ?.
- rewrite /g.
- destruct (mx_eval_quot_is_surjective cop imn0 x).
- by simpl.
- Qed.
-
-
-End eval_vectors.
-
-Definition odd_nth_roots' n :=
- \row_(j < (S (proj1_sig (pow2_S n))))
- (nth_root (2 * j + 1) (2 ^ (S n))).
-
-Lemma odd_nth_roots_minpoly n :
- forall i,
- root ('X^(2^n) + 1%:P) (odd_nth_roots n 0 i).
-Proof.
- move=> i.
- rewrite /odd_nth_roots mxE /root hornerD hornerXn hornerC.
- generalize (odd_roots_prim i n); intros.
- apply (f_equal (fun z => C1 + z)) in H.
- rewrite addrN in H.
- rewrite addrC in H.
- apply /eqP.
- unfold zero in *; simpl in *.
- unfold real_complex_def in *.
- unfold zero in *; simpl in *.
- by rewrite <- H.
-Qed.
-
-Lemma odd_nth_roots_minpoly_complex n :
- forall (c : R[i]),
- root ('X^(2^(S n)) + 1%:P) c ->
- Im c <> 0.
-Proof.
- intros.
- unfold root in H.
- rewrite hornerD hornerXn hornerC in H.
- move /eqP in H.
- unfold not; intros.
- destruct c.
- simpl in H0.
- rewrite H0 in H.
- assert (forall x:R, x^+2 + 1 <> 0).
- {
- intros.
- apply Rgt_not_eq.
- generalize (pow2_ge_0 x); intros.
- rewrite /one /add /zero /= -RpowE.
- coq_lra.
- }
- assert (forall x:R, x^+(2^(S n)) + 1 <> 0).
- {
- intros.
- by rewrite expnS mulnC exprM.
- }
- clear H0 H1.
- replace ((Re +i* 0)%C ^+ (2 ^ n.+1)) with (RtoC (Re ^+ (2 ^ n.+1))) in H.
- - unfold RtoC in H.
- move /eqP in H.
- rewrite eq_complex in H.
- move /andP in H.
- destruct H.
- simpl in H.
- move /eqP in H.
- by specialize (H2 Re).
- - assert (forall n,
- RtoC (Re ^+ n) = (Re +i* 0)%C ^+n).
- {
- induction n0.
- - by rewrite !expr0.
- - rewrite !exprS.
- assert (forall (x y : R),
- RtoC (x * y) = RtoC x * RtoC y).
- {
- apply RtoC_is_rmorphism.
- }
- by rewrite H0 IHn0.
- }
- apply H0.
- Qed.
-
-Lemma minpoly_mult_odd_nth_roots n (p : {poly R[i]}) :
- Pdiv.Ring.rmodp p ('X^(2^n) + 1%:P) = 0 ->
- forall i, root p (odd_nth_roots n 0 i).
-Proof.
-intros.
-move=> /Pdiv.Ring.rmodp_eq0P in H.
-rewrite Pdiv.ComRing.rdvdp_eq in H.
-destruct (pow2_S n).
-generalize (Xn_add_c_monic (R:=ComplexField.complex_comRingType R_fieldType) x 1); intros.
-rewrite monicE in H0.
-move=> /eqP in H0.
-rewrite -(eqP i0) in H0.
-rewrite H0 in H.
-rewrite Theory.expr1n in H.
-move=> /eqP in H.
-rewrite Theory.scale1r in H.
-rewrite H.
-unfold root.
-apply /eqP.
-rewrite hornerM.
-generalize (odd_nth_roots_minpoly n i); intros.
-unfold root in H1.
-move=> /eqP in H1.
-rewrite H1.
-rewrite mulr0 //.
-Qed.
-
-
-Lemma drop_poly_opp [S : ringType] n (p : {poly S}) :
- drop_poly n (- p) = - drop_poly n p.
-Proof.
- rewrite -!scaleN1r.
- apply drop_polyZ.
-Qed.
-
-Lemma drop_poly_diff [S : ringType] n (p q : {poly S}) :
- drop_poly n p = drop_poly n q ->
- seq.size (p - q) <= n.
-Proof.
- intros.
- assert (drop_poly n (p - q) = 0).
- {
- by rewrite drop_polyD drop_poly_opp -H addrN.
- }
- generalize (poly_take_drop n (p - q)); intro decomp.
- rewrite H0 mul0r addr0 in decomp.
- rewrite -decomp.
- apply size_take_poly.
-Qed.
-
-Lemma monic_size_pos [S : ringType] (p : {poly S}) :
- p \is monic ->
- seq.size p > 0.
-Proof.
- case: ltP => // n0lt.
- have: seq.size p == 0%nat by lia.
- rewrite size_poly_eq0 => eqq1 /monic_neq0.
- by rewrite eqq1.
-Qed.
-
-Lemma monic_drop_n_1 [S : ringType] n (p : {poly S}) :
- p \is monic ->
- seq.size p = n.+1 ->
- drop_poly n p = 1.
-Proof.
- rewrite monicE /lead_coef.
- rewrite /drop_poly poly_def => /[swap] -> /=.
- replace (n.+1 - n)%nat with 1%nat by lia.
- rewrite big_ord1 add0n => /eqP->.
- by rewrite expr0 alg_polyC.
-Qed.
-
-Lemma monic_dif_same_deg (p q : {poly R[i]}) :
- p \is monic ->
- q \is monic ->
- seq.size p = seq.size q ->
- seq.size (q - p) < seq.size p.
-Proof.
- intros.
- pose (n := seq.size p).
- pose (n1 := (n-1)%nat).
- assert (n = S n1).
- {
- generalize (monic_size_pos p H); lia.
- }
- unfold n in H2.
- rewrite H2.
- generalize (drop_poly_diff n1 q p); intros.
- apply H3.
- rewrite monic_drop_n_1; trivial.
- - rewrite monic_drop_n_1; trivial.
- - rewrite -H1 H2 //.
- Qed.
-
-Lemma monic_divides_same_deg (p q : {poly R[i]}) :
- p \is monic ->
- q \is monic ->
- seq.size p = seq.size q ->
- Pdiv.Ring.rdvdp (R:=ComplexField.complex_unitRingType R_realFieldType)
- p q ->
- p = q.
-Proof.
- intros.
- generalize (Pdiv.RingMonic.rdivp_eq H q); intros.
- generalize (monic_dif_same_deg _ _ H H0 H1); intros.
- generalize (Pdiv.RingMonic.redivp_eq H 1 H4); intros.
- rewrite mul1r (addrC q _) addrA addrN add0r in H5.
- assert (Pdiv.CommonRing.rdivp (R:=ComplexField.complex_ringType R_fieldType) q p = 1).
- {
- unfold Pdiv.CommonRing.rdivp.
- by rewrite H5 /=.
- }
- rewrite H6 mul1r in H3.
- unfold Pdiv.Ring.rdvdp in H2.
- move=> /eqP in H2.
- rewrite H2 addr0 in H3.
- by symmetry.
-Qed.
-
-Lemma seq_all_odd_roots n (p : {poly R[i]}) :
- let rs := MatrixFormula.seq_of_rV (odd_nth_roots n) in
- (forall i, root p (odd_nth_roots n 0 i)) ->
- seq.all (root p) (MatrixFormula.seq_of_rV (odd_nth_roots n)).
-Proof.
- intros.
- apply/tuple.all_tnthP => /= i.
- rewrite finfun.tnth_fgraph /= finfun.ffunE.
- by apply H.
-Qed.
-
-Lemma odd_roots_uniq' n :
- forall i j,
- odd_nth_roots n 0 i = odd_nth_roots n 0 j ->
- i = j.
-Proof.
- rewrite /odd_nth_roots => i j.
- rewrite !mxE.
- destruct (pow2_S (S n)) as [? eqq1].
- rewrite (eqP eqq1) => /nth_root_eq.
- rewrite -(eqP eqq1).
- rewrite !Nat.mod_small /=.
- - move/addIn => eqq2.
- apply/ord_inj.
- lia.
- - rewrite expnS.
- destruct i; destruct j; simpl in *.
- lia.
- - rewrite expnS.
- destruct i; destruct j; simpl in *.
- lia.
-Qed.
-
-Lemma odd_roots_uniq n :
- let rs := MatrixFormula.seq_of_rV (odd_nth_roots n) in
- uniq_roots (MatrixFormula.seq_of_rV (odd_nth_roots n)).
-Proof.
- rewrite uniq_rootsE /= /odd_nth_roots.
- rewrite /MatrixFormula.seq_of_rV.
- apply /tuple.tuple_uniqP => i j.
- rewrite !finfun.tnth_fgraph !finfun.ffunE => eqq1.
- apply odd_roots_uniq' in eqq1.
- by apply enum_val_inj.
-Qed.
-
-Lemma odd_nth_roots_minpoly_mult n (p : {poly R[i]}) :
- (forall i, root p (odd_nth_roots n 0 i)) ->
- Pdiv.Ring.rmodp p ('X^(2^n) + 1%:P) = 0 .
-Proof.
- intros roots.
- move: (seq_all_odd_roots n p roots) (odd_roots_uniq n) => allroot uniqroot.
- generalize (Pdiv.UnitRing.uniq_roots_rdvdp allroot uniqroot); intros.
- pose (rs := MatrixFormula.seq_of_rV (odd_nth_roots n)).
- assert (\prod_(z <- rs) ('X - z%:P) = 'X^(2 ^ n) + 1%:P).
- {
- apply monic_divides_same_deg.
- - apply monic_prod_XsubC.
- - destruct (pow2_S n).
- rewrite (eqP i).
- apply Xn_add_c_monic.
- - rewrite size_prod_XsubC.
- destruct (pow2_S n).
- rewrite (eqP i).
- rewrite size_Xn_addC.
- rewrite -(eqP i).
- by rewrite MatrixFormula.size_seq_of_rV.
- - generalize (seq_all_odd_roots n _ (odd_nth_roots_minpoly n)); intros allroot2.
- apply (Pdiv.UnitRing.uniq_roots_rdvdp allroot2 uniqroot).
- }
- rewrite H0 in H.
- unfold Pdiv.Ring.rdvdp in H.
- by move=> /eqP in H.
- Qed.
-
-Lemma map_RtoC_size (p : {poly R}) :
- seq.size p = seq.size (map_poly RtoC p).
-Proof.
- by rewrite (size_map_poly RtoC_rmorphism p).
-Qed.
-
-Lemma map_RtoC_lead_coef (p : {poly R}) :
- lead_coef (map_poly RtoC p) = RtoC (lead_coef p).
-Proof.
- by rewrite (lead_coef_map RtoC_rmorphism p).
-Qed.
-
-Lemma rmodp_RtoC_morph : {morph (map_poly RtoC) : p q / Pdiv.Ring.rmodp p q }.
-Proof.
- rewrite /Pdiv.Ring.rmodp => p q.
- by rewrite redivp_map.
-Qed.
-
-Lemma RtoC_inj : injective RtoC.
-Proof.
- rewrite /RtoC=>a b.
- by move=> [].
-Qed.
-
-Lemma map_poly_RtoC_eq0E p : map_poly RtoC p = 0 <-> (p = 0).
-Proof.
- split.
- - rewrite -(map_poly0 RtoC).
- move/map_inj_poly.
- apply => //.
- by apply RtoC_inj.
- - move=> ->.
- by rewrite map_poly0.
-Qed.
-
-Lemma rmodp_R (p q : {poly R}) :
- Pdiv.Ring.rmodp p q = 0 <-> Pdiv.Ring.rmodp (map_poly RtoC p) (map_poly RtoC q) = 0.
-Proof.
- rewrite -rmodp_RtoC_morph.
- symmetry.
- apply map_poly_RtoC_eq0E.
-Qed.
-
-Lemma map_poly_add_RtoC (p q : {poly R}) :
- map_poly RtoC (p + q) = (map_poly RtoC p) + (map_poly RtoC q).
-Proof.
- apply (raddfD map_RtoC_rmorphism).
-Qed.
-
-Lemma map_RtoC_Xnpoly n :
- map_poly (aR:=R_ringType) (rR:=ComplexField.complex_ringType R_fieldType) RtoC
- ('X^(2 ^ n) + 1%:P) = 'X^(2 ^ n) + 1%:P.
-Proof.
- rewrite map_poly_add_RtoC map_polyXn.
- f_equal.
- by rewrite (map_polyC RtoC_rmorphism 1).
-Qed.
-
-Lemma odd_nth_roots_minpoly_mult_R n (p : {poly R}) :
- (forall i, root (map_poly RtoC p) (odd_nth_roots n 0 i)) ->
- Pdiv.Ring.rmodp p ('X^(2^n) + 1%:P) = 0.
-Proof.
- intros.
- generalize (odd_nth_roots_minpoly_mult n (map_poly RtoC p) H); intros.
- rewrite rmodp_R.
- rewrite <- H0.
- f_equal.
- apply map_RtoC_Xnpoly.
-Qed.
-
-Lemma minpoly_mult_odd_nth_roots_R n (p : {poly R}) :
- Pdiv.Ring.rmodp p ('X^(2^n) + 1%:P) = 0 ->
- forall i, root (map_poly RtoC p) (odd_nth_roots n 0 i).
-Proof.
- intros.
- rewrite rmodp_R in H.
- rewrite map_RtoC_Xnpoly in H.
- by generalize (minpoly_mult_odd_nth_roots n (map_poly RtoC p) H).
-Qed.
-
-Lemma odd_nth_roots_quot n (p : {poly R}) :
- mx_eval_ker_pred (odd_nth_roots' n) p <->
- princ_ideal_pred ('X^(2^n) + 1%:P) p.
-Proof.
- unfold mx_eval_ker_pred, princ_ideal_pred.
- split; intros.
- - apply /eqP.
- apply odd_nth_roots_minpoly_mult_R.
- intros.
- move=> /eqP in H.
- apply matrixP in H.
- unfold odd_nth_roots' in H.
- unfold odd_nth_roots.
- unfold root.
- destruct i.
- assert (m < (sval (pow2_S n)).+1) by (simpl; lia).
- specialize (H 0 (Ordinal H0)).
- simpl in H.
- rewrite !mxE in H.
- rewrite !mxE.
- replace (2 * Ordinal (n:=2 ^ n) (m:=m) i + 1)%N with
- (2 * Ordinal (n:=(2 ^ n - 1).+1) (m:=m) H0 + 1)%N.
- + by rewrite H.
- + f_equal.
- - unfold zero; simpl.
- unfold const_mx.
- apply /eqP.
- apply matrixP.
- intros ??.
- rewrite !mxE.
- move=> /eqP in H.
- generalize (minpoly_mult_odd_nth_roots_R n p H); intros.
- unfold pow2_S in y.
- destruct y as [y' ?].
- assert (y' < 2^n) by lia.
- specialize (H0 (Ordinal H1)).
- unfold root in H0.
- move=> /eqP in H0.
- simpl in H0.
- unfold odd_nth_roots in H0.
- rewrite mxE in H0.
- apply H0.
- Qed.
-
-
-Section norms.
-
- Import matrix_ring.
-
- Implicit Types x y : R[i].
-
- Notation normc := ComplexField.Normc.normc.
-
- Definition norm1 {n} (v : 'rV[R[i]]_n):R := \sum_(j < n) normc (v 0 j).
-
- Definition norm_inf {n} (v : 'rV[R[i]]_n):R := \big[Order.max/0]_(j < n) normc (v 0 j).
-
- Definition cvec_norm_inf {n} (v : 'cV[R[i]]_n):R := norm_inf(v^T).
- Definition cvec_norm1 {n} (v : 'cV[R[i]]_n):R := norm1(v^T).
-
- Definition matrix_norm_inf {n m} (mat : 'M[R[i]]_(n,m)) :=
- \big[Order.max/0]_(j i _.
- rewrite (bigD1 i) // /= big1_idem ?addr0 //.
- - by rewrite /diag_mx mxE eq_refl /= mulr1n.
- - intros.
- rewrite mxE eq_sym.
- by rewrite (negbTE H) ComplexField.Normc.normc0.
- Qed.
-
- Lemma normc_nneg (x : R[i]) :
- (R0 <= normc x)%O.
- Proof.
- rewrite /normc.
- case: x => r i.
- apply ssrnum.Num.Theory.sqrtr_ge0.
- Qed.
-
- Lemma sum_normc_nneg {n} (v : 'rV[R[i]]_n) :
- ((@zero R_zmodType) <= \sum_(k < n) normc (R:=R_rcfType) (v ord0 k))%O.
- Proof.
- apply big_rec => // i x _ xnneg.
- by rewrite ssrnum.Num.Theory.addr_ge0 // normc_nneg.
- Qed.
-
- Lemma mat_vec_norm_inf {n} (v : 'rV[R[i]]_n) :
- norm_inf v = matrix_norm_inf (v^T).
- Proof.
- rewrite /norm_inf /matrix_norm_inf /=.
- apply eq_bigr => j _.
- by rewrite big_ord_recl big_ord0 addr0 mxE.
- Qed.
-
- Lemma mat_vec_norm1 {n} (v : 'rV[R[i]]_n) :
- norm1 v = matrix_norm1 (v^T).
- Proof.
- rewrite /norm1 /matrix_norm1 /matrix_norm_inf /=.
- rewrite big_ord_recl big_ord0 Order.POrderTheory.max_l ?sum_normc_nneg //.
- by apply eq_bigr => j _; rewrite !mxE.
- Qed.
-
-
- Lemma R00 : R0 = 0.
- Proof.
- by [].
- Qed.
-
-
- Lemma normc_nnegR (x : R[i]) :
- Rle R0 (normc x).
- Proof.
- move: (normc_nneg x).
- rewrite /Order.le /=.
- by move/RlebP.
- Qed.
-
- Lemma maxrM (c a b : R) : Rle 0 c -> Order.max (c * a) (c * b) = c * Order.max a b.
- Proof.
- rewrite /Order.max /Order.lt /=.
- (repeat case: RltbP); [lra | | | lra]; intros.
- - destruct H.
- + apply Rmult_lt_reg_l in p => //.
- + subst.
- by rewrite !mul0r.
- - destruct H.
- + elim n.
- by apply Rmult_lt_compat_l.
- + subst.
- by rewrite !mul0r.
- Qed.
-
- Lemma maxrM_l (c a b : R) : Rle 0 c -> Order.max (a * c) (b * c) = (Order.max a b)*c.
- Proof.
- rewrite /Order.max /Order.lt /=.
- (repeat case: RltbP); [lra | | | lra]; intros.
- - destruct H.
- + apply Rmult_lt_reg_r in p => //.
- + subst.
- by rewrite !mulr0.
- - destruct H.
- + elim n.
- by apply Rmult_lt_compat_r.
- + subst.
- by rewrite !mulr0.
- Qed.
-
- Lemma norm_infZ {n} (c : R[i]) (v : 'rV[R[i]]_n) :
- norm_inf (c *: v) = (normc c) * norm_inf v.
- Proof.
- rewrite /norm_inf.
- apply (big_rec2 (fun a b => a = normc c * b)).
- - by rewrite mulr0.
- - move=> i a b _ ->.
- rewrite mxE ComplexField.Normc.normcM maxrM //.
- apply normc_nnegR.
- Qed.
-
- Lemma norm_inf_nneg {n} (v : 'rV[R[i]]_n) :
- (@zero R_ringType <= norm_inf v)%O.
- Proof.
- rewrite /norm_inf.
- apply big_rec => //= i x _ xn.
- by rewrite Order.TotalTheory.le_maxr xn orbT.
- Qed.
-
- Lemma normc_triang (x y : R[i]) :
- (normc (x + y) <= normc x + normc y)%O.
- Proof.
- generalize (ComplexField.lec_normD x y); intros.
- rewrite /ComplexField.lec /= addr0 in H.
- move => /andP in H.
- by destruct H.
- Qed.
-
- Lemma normc_triangR (x y : R[i]) :
- Rle (normc (x + y)) (normc x + normc y).
- Proof.
- move: (normc_triang x y) => /=.
- rewrite /Order.le /Order.POrder.le /=.
- by move/RleP.
- Qed.
-
- Lemma normc_triang_sum {n} (a : 'I_n -> R[i]) :
- Rleb (normc (\sum_(j R) (c : R) :
- (\sum_(j R) (c : R) :
- c * (\sum_(j R) (c : R) :
- Rle 0 c ->
- (\big[Order.max/0]_(j R) (c : R) :
- Rle 0 c ->
- c * (\big[Order.max/0]_(j R) :
- (forall j, Rleb (a j) (b j)) ->
- Rleb (\sum_(j
- Rleb c d ->
- Rleb (Order.max a c) (Order.max b d).
- Proof.
- rewrite -!RmaxE.
- intros.
- apply /RlebP.
- move /RlebP in H.
- move /RlebP in H0.
- apply Rmax_case; apply Rmax_Rle.
- - by left.
- - by right.
- Qed.
-
- Lemma bigmax_le2 {n} (a b : 'I_n -> R) :
- (forall j, Rleb (a j) (b j)) ->
- Rleb (\big[Order.max/0]_(jR):
- Order.le init (\big[Order.max/init]_(j < n) f (v 0 j)).
- Proof.
- rewrite BigOp.bigopE.
- unlock reducebig.
- elim: (index_enum _) => /=.
- - by exact: Order.POrderTheory.lexx.
- - move=> a l.
- rewrite Order.TotalTheory.le_maxr => ->.
- by rewrite orbT.
- Qed.
-
- Lemma omax_l {disp : Datatypes.unit} {T : porderType disp} (x y:T) : Order.le x (Order.max x y).
- Proof.
- rewrite Order.POrderTheory.maxEle.
- by case_eq (Order.le x y).
- Qed.
-
- Lemma omax_r {disp : Datatypes.unit} {T : orderType disp} (x y:T) : Order.le y (Order.max x y).
- Proof.
- rewrite Order.POrderTheory.maxElt.
- case_eq (Order.lt x y) => //.
- rewrite (Order.TotalTheory.leNgt y x).
- by move/negbT.
- Qed.
-
- Lemma omax_l_real (x y:R) : Rleb x (Order.max x y).
- Proof.
- by exact: (omax_l x y).
- Qed.
-
- Lemma omax_r_real (x y:R) : Rleb y (Order.max x y).
- Proof.
- by exact: (omax_r x y).
- Qed.
-
- Lemma bigmaxr_le {n} (v : 'rV[R[i]]_n) init f i:
- Rleb (f (v 0 i)) (\big[Order.max/init]_(j < n) f (v 0 j)).
- Proof.
- rewrite BigOp.bigopE.
- unlock reducebig.
- move: (mem_index_enum i).
- elim: (index_enum _) => /= [| a l IHl].
- - by rewrite seq.in_nil.
- - rewrite seq.in_cons => /orP [/eqP->| ].
- + by exact: omax_l_real.
- + move=> inn.
- move: (IHl inn) => /RlebP=>le1.
- apply/RlebP.
- eapply Rle_trans; try apply le1.
- apply/RlebP.
- apply omax_r_real.
- Qed.
-
- Lemma bigmaxr_le_alt {n} (v : 'I_n -> R) init i:
- Rleb (v i) (\big[Order.max/init]_(j < n) (v j)).
- Proof.
- rewrite BigOp.bigopE.
- unlock reducebig.
- move: (mem_index_enum i).
- elim: (index_enum _) => /= [| a l IHl].
- - by rewrite seq.in_nil.
- - rewrite seq.in_cons => /orP [/eqP->| ].
- + by exact: omax_l_real.
- + move=> inn.
- move: (IHl inn) => /RlebP=>le1.
- apply/RlebP.
- eapply Rle_trans; try apply le1.
- apply/RlebP.
- apply omax_r_real.
- Qed.
-
- Lemma mat_vec_norm_bound1 {n m}
- (mat : 'M[R[i]]_(n, m))
- (vec : 'rV[R[i]]_m) k:
- Rleb (normc (\sum_(j j.
- rewrite ComplexField.Normc.normcM.
- apply /RlebP.
- apply Rmult_le_compat_l.
- + apply normc_nnegR.
- + apply /RlebP.
- apply bigmaxr_le.
- Qed.
-
-
- Lemma bigmax_nneg {n} (v : 'I_n -> R) :
- (forall i, Rleb 0 (v i)) ->
- Rleb 0 (\big[Order.max/0]_(j < n) (v j)).
- Proof.
- intros.
- apply big_rec.
- - apply /RlebP.
- unfold zero; simpl.
- coq_lra.
- - intros.
- assert ((zero R_ringType) <= Order.max (v i) x)%O.
- {
- rewrite Order.TotalTheory.le_maxr.
- apply /orP.
- left.
- apply H.
- }
- by rewrite /Order.le /= in H1.
- Qed.
-
- Lemma bigmax_normc_nneg {n} (v : 'rV[R[i]]_n):
- Rleb 0 (\big[Order.max/0]_(j < n) normc (v 0 j)).
- Proof.
- apply bigmax_nneg => i.
- apply normc_nneg.
- Qed.
-
- Lemma sum_nneg {n} (v : 'I_n -> R) :
- (forall i, Rleb 0 (v i)) ->
- Rleb 0 (\sum_(j < n) (v j)).
- Proof.
- intros.
- apply big_rec.
- - apply /RlebP.
- unfold zero; simpl.
- coq_lra.
- - intros.
- apply /RlebP.
- rewrite /add /=.
- apply Rplus_le_le_0_compat; by apply /RlebP.
- Qed.
-
- Lemma matrix_vec_norm_inf_sub_mult {n m}
- (mat : 'M[R[i]]_(n, m))
- (vec : 'rV[R[i]]_m) :
- Rleb (matrix_norm_inf (mat *m (vec^T)))
- ((matrix_norm_inf mat) * (norm_inf vec)).
- Proof.
- rewrite /matrix_norm_inf /norm_inf /mulmx /=.
- generalize (@max_mult_distr n); intros.
- rewrite /mul /= in H.
- rewrite H.
- - apply bigmax_le2.
- intros.
- rewrite big_ord_recl big_ord0 addr0 mxE /trmx /=.
- under eq_bigr do rewrite mxE.
- by rewrite mat_vec_norm_bound1.
- - apply /RlebP.
- apply bigmax_normc_nneg.
- Qed.
-
- Lemma sum_plus {n} (a b : 'I_n -> R) :
- \sum_(i 'I_m -> R) :
- \sum_(i j.
- under eq_bigr do rewrite mxE.
- apply /RlebP.
- apply Rle_trans with
- (r2 := \sum_(i < p) (\sum_(j0 < m) normc (R:=R_rcfType) (mat1 j j0 * mat2 j0 i))).
- + apply /RlebP.
- apply sum_le => j0.
- apply normc_triang_sum.
- + rewrite exchange_sums.
- generalize (@sum_mult_distr); intros.
- rewrite /mul /= in H0.
- rewrite H0.
- apply /RlebP.
- apply sum_le => j0.
- replace (\sum_(i < p) normc (R:=R_rcfType) (mat1 j j0 * mat2 j0 i)) with
- ((normc (R:=R_rcfType) (mat1 j j0)) *
- (\sum_(i < p) normc (R:=R_rcfType) (mat2 j0 i))).
- * apply /RlebP.
- apply Rmult_le_compat_l.
- -- apply normc_nnegR.
- -- apply /RlebP.
- generalize (@bigmaxr_le_alt m
- (fun j0 => (\sum_(i < p) normc (R:=R_rcfType) (mat2 j0 i))) 0 j0); intros.
- apply H1.
- * rewrite mulrC sum_mult_distr.
- apply eq_bigr => i _.
- by rewrite mulrC ComplexField.Normc.normcM.
- - apply /RlebP.
- apply bigmax_nneg => i.
- apply sum_nneg => k.
- apply normc_nneg.
- Qed.
-
- Lemma matrix_norm_inf_scale {n m} (mat : 'M[R[i]]_(n,m)) (c : R[i]) :
- matrix_norm_inf (scalemx c mat) = (normc c)*(matrix_norm_inf mat).
- Proof.
- rewrite /matrix_norm_inf /scalemx max_mult_distl.
- - apply eq_bigr => j _.
- rewrite sum_mult_distl.
- apply eq_bigr => k _.
- by rewrite mxE ComplexField.Normc.normcM.
- - apply normc_nnegR.
- Qed.
-
- Lemma big_max_const (n:nat) (c : R) : n != 0%nat ->
- Rle 0 c ->
- \big[Order.max/0]_(j < n) c = c.
- Proof.
- case: n; [lia |intros n _].
- induction n.
- - rewrite big_ord_recl big_ord0 => H.
- rewrite Order.POrderTheory.max_l //.
- apply /RlebP.
- apply H.
- - rewrite big_ord_recl => H.
- rewrite IHn.
- rewrite Order.POrderTheory.max_l //.
- apply H.
- Qed.
-
- Lemma normc_conj (x : R[i]) :
- ComplexField.Normc.normc x = ComplexField.Normc.normc (conjc x).
- Proof.
- case: x => rx ix /=.
- by rewrite sqrrN.
- Qed.
-
- Lemma normc_nth_root j (n:nat) :
- n != 0%nat ->
- normc (nth_root j n) = 1.
- Proof.
- rewrite /normc /nth_root => nN0.
- rewrite -!RpowE -!Rsqr_pow2 addrC /add /=.
- by rewrite sin2_cos2 ssrnum.Num.Theory.sqrtr1.
- Qed.
-
- Lemma big_max_const_fun (n : nat) (a : 'I_n -> R) (c : R) : n != 0%nat ->
- Rle 0 c ->
- (forall i, a i = c) ->
- \big[Order.max/0]_(i < n) (a i) = c.
- Proof.
- intros.
- under eq_bigr do rewrite H1.
- by apply big_max_const.
- Qed.
-
- Lemma norm_inf_const_norm (n : nat) (vec : 'rV[R[i]]_n.+1) :
- (forall i , normc (R:=R_rcfType) (vec 0 i) = 1) ->
- norm_inf vec = 1.
- Proof.
- intros.
- rewrite /norm_inf.
- apply big_max_const_fun; trivial.
- rewrite /one /=; coq_lra.
- Qed.
-
- Lemma pow2n0 n : (2 ^ n)%N != 0%N.
- Proof.
- by rewrite expn_eq0.
- Qed.
-
- Hint Immediate pow2n0.
-
- Lemma norm_inf_conj_half_roots (n : nat) :
- norm_inf (map_mx conjc (nth_roots_half n.+1)) = 1.
- Proof.
- rewrite /norm_inf /nth_roots_half.
- apply big_max_const_fun.
- - apply pow2n0.
- - rewrite /one/=; coq_lra.
- - intros.
- by rewrite !mxE -normc_conj normc_nth_root // pow2n0.
- Qed.
-
- Lemma big_sum_const (n : nat) (c : R) :
- \sum_(j i0 _.
- rewrite /even_nth_roots !mxE -exp_conj -normc_conj pow_nth_root' ?normc_nth_root.
- over.
- apply pow2n0.
- apply pow2n0.
- by rewrite big_sum_const.
- Qed.
-
- Lemma encode_mat_norm_inf (n : nat) :
- let pmat := peval_mat (odd_nth_roots (S n)) in
- let encmat := (conj_mat (pmat^T)) in
- Rleb (matrix_norm_inf encmat) (2^S n)%:R.
- Proof.
- rewrite /= encode_mat_prod.
- apply /RlebP.
- eapply Rle_trans.
- - apply /RlebP.
- apply matrix_norm_inf_sub_mult.
- - rewrite -norm_inf_diag norm_inf_conj_half_roots Rmult_1_l.
- rewrite norm_inf_peval_mat_conj_even_roots.
- coq_lra.
- Qed.
-
-
- Lemma norm_inf_triang {n} (v1 v2 : 'rV[R[i]]_n) :
- (norm_inf (v1 + v2) <= norm_inf v1 + norm_inf v2)%O.
- Proof.
- rewrite /norm_inf.
- apply big_rec.
- - unfold Order.le; simpl.
- apply /RlebP.
- erewrite <- addr0 at 1.
- apply Rplus_le_compat; apply /RlebP; apply bigmax_normc_nneg.
- - intros i x _ xn.
- rewrite Order.TotalTheory.le_maxl.
- generalize (normc_triang (v1 0 i) (v2 0 i)); intros.
- rewrite mxE.
- apply /andP; split; trivial.
- unfold Order.le in *; simpl in *.
- apply /RlebP.
- move => /RlebP in H.
- eapply Rle_trans.
- apply H.
- apply Rplus_le_compat; apply /RlebP; apply bigmaxr_le.
- Qed.
-
- Lemma norm_inf_semi_multiplicative {n} (v1 v2 : 'rV[R[i]]_n) :
- (norm_inf (map2_mx (fun (a b : R[i]) => a * b) v1 v2) <= norm_inf v1 * norm_inf v2)%O.
- Proof.
- rewrite /norm_inf.
- apply big_rec.
- - unfold Order.le, zero, mul; simpl.
- apply /RlebP; apply Rmult_le_pos; apply /RlebP; apply bigmax_normc_nneg.
- - intros i x _ xn.
- rewrite Order.TotalTheory.le_maxl mxE ComplexField.Normc.normcM.
- apply /andP; split; trivial.
- clear x xn.
- unfold Order.le; simpl.
- apply /RlebP.
- apply Rmult_le_compat; try apply normc_nnegR; apply /RlebP; apply bigmaxr_le.
- Qed.
-
- Lemma norm_inf_pos_def {n} (v : 'rV[R[i]]_n) :
- norm_inf v = 0 -> v = 0.
- Proof.
- rewrite /norm_inf => HH.
- apply /matrixP => a b.
- move: (ord1 a)->.
- move: (bigmaxr_le v 0 (@normc _) b).
- rewrite {}HH.
- move/RlebP => HH.
- rewrite [v 0 b]ComplexField.eq0_normC ?mxE//.
- move: (normc_nnegR (v 0 b)) => HH2.
- rewrite /zero /=.
- f_equal.
- now apply Rle_antisym.
- Qed.
-
- Lemma normc_Rabs (r : R) :
- normc (RtoC r) = Rabs r.
- Proof.
- rewrite /normc /RtoC (expr2 0) mulr0 addr0.
- by rewrite ssrnum.Num.Theory.sqrtr_sqr.
- Qed.
-
- Lemma mx_evalZ {n} (v : 'rV[R[i]]_n.+1) (r:R) p :
- mx_eval v (r *: p) = (RtoC r) *: (mx_eval v p).
- Proof.
- apply matrixP => a b.
- rewrite !mxE /scale /= scale_polyE.
- rewrite rmorphismMP /= (map_polyC RtoC_rmorphism) /=.
- by rewrite -hornerZ /= /scale /= scale_polyE.
- Qed.
-
- (* following 4 lemmas show canon_norm_inf is a norm on quotient by x^+(2^n) + 1 *)
- Lemma canon_norm_infZ n (r : R) (p : {poly R}) :
- canon_norm_inf n (r *: p) = Rabs r * canon_norm_inf n p.
- Proof.
- by rewrite /canon_norm_inf !mx_evalZ norm_infZ normc_Rabs.
- Qed.
-
- Lemma canon_norm_inf_nneg n (p : {poly R}) :
- (zero R_ringType <= canon_norm_inf n p)%O.
- Proof.
- apply norm_inf_nneg.
- Qed.
-
- Lemma canon_norm_inf_triang n (p q : {poly R}) :
- (canon_norm_inf n (p + q) <= canon_norm_inf n p + canon_norm_inf n q)%O.
- Proof.
- rewrite /canon_norm_inf.
- move: (raddfD (mx_eval_rmorphism (odd_nth_roots' n)) p q) => /= ->.
- apply norm_inf_triang.
- Qed.
-
- Lemma canon_norm_inf_semi_multiplicative n (p q : {poly R}) :
- (canon_norm_inf n (p * q) <= canon_norm_inf n p * canon_norm_inf n q)%O.
- Proof.
- rewrite /canon_norm_inf rmorphM.
- apply norm_inf_semi_multiplicative.
- Qed.
-
- Lemma mx_eval_pow {n} (v : 'rV_n.+1) e :
- mx_eval v 'X^e = map_mx (fun c => c ^ e) v.
- Proof.
- apply eq_map_mx.
- intros ?.
- by rewrite map_polyXn hornerXn /exprz.
- Qed.
-
-
- Lemma iter_max {disp : Datatypes.unit} {T : porderType disp} i (a b: T) :
- i != 0%nat ->
- ssrnat.iter i (Order.max a) b = Order.max a b.
- Proof.
- induction i => //=.
- destruct i => // _.
- rewrite IHi //=.
- by rewrite Order.POrderTheory.max_maxxK.
- Qed.
-
- Lemma canon_norm_inf_pow n e :
- canon_norm_inf n 'X^e = 1.
- Proof.
- rewrite /canon_norm_inf mx_eval_pow /odd_nth_roots'.
- unfold norm_inf.
- under eq_big_seq.
- {
- intros.
- rewrite !mxE.
- rewrite /exprz.
- rewrite pow_nth_root'; try lia.
- rewrite normc_nth_root; try lia.
- over.
- }
- simpl.
- rewrite big_const_ord iter_max //.
- rewrite /Order.max.
- case: RltP => //.
- rewrite /one/zero/=.
- move/lt_IZR.
- lia.
- Qed.
-
- Lemma canon_norm_inf_C_pow n e c :
- canon_norm_inf n (c *: 'X^e) = Rabs c.
- Proof.
- by rewrite canon_norm_infZ canon_norm_inf_pow mulr1.
- Qed.
-
- Lemma canon_norm_inf_C n (c : R) :
- canon_norm_inf n (polyC c) = Rabs c.
- Proof.
- rewrite -(canon_norm_inf_C_pow n 0 c).
- f_equal.
- by rewrite expr0 /scale /Lmodule.scale /= scale_polyE mulr1.
- Qed.
-
- Lemma coef_norm1_C (c : R) :
- coef_norm1 (polyC c) = Rabs c.
- Proof.
- rewrite /coef_norm1 size_polyC.
- case: eqVneq; simpl.
- - rewrite big_ord0 => ->.
- by rewrite Rabs_R0.
- - by rewrite big_ord1 coefC.
- Qed.
-
- Lemma size_poly_def (p : {poly R}) :
- size (\sum_(i < size p) p`_i *: 'X^i) = size p.
- Proof.
- case: (eqVneq p 0); intros.
- - by rewrite e poly_size_0 big_ord0 poly_size_0.
- - rewrite -poly_def size_poly_eq //.
- by rewrite -lead_coefE lead_coef_eq0.
- Qed.
-
- Lemma coef_norm1_poly_def (p : {poly R}) :
- coef_norm1 (\sum_(i < (size p)) p`_i *: 'X^i) =
- \sum_(i < (size p)) Rabs p`_i.
- Proof.
- rewrite /coef_norm1 size_poly_def.
- apply eq_big => //= i _.
- f_equal.
- by rewrite -poly_def coef_poly ltn_ord.
- Qed.
-
- Lemma canon_norm_inf_poly_def n m (p : {poly R}) :
- (canon_norm_inf n (\sum_(i < m) p`_i *: 'X^i) <=
- \sum_(i < m) canon_norm_inf n (p`_i *: 'X^i))%O.
- Proof.
- induction m.
- - rewrite !big_ord0 canon_norm_inf_C Rabs_R0.
- apply Order.POrderTheory.le_refl.
- -
- move: (@big_nat_recr {poly R} 0 (add_monoid _)
- m
- 0
- (fun i => p`_i *: 'X^i)
- (leq0n _)
- ).
- rewrite !big_mkord => ->.
- move: (@big_nat_recr R 0 (@add_monoid _)
- m
- 0
- (fun i => canon_norm_inf n (p`_i *: 'X^i))
- (leq0n _)
- ).
- rewrite !big_mkord => -> /=.
- eapply Order.POrderTheory.le_trans.
- + apply canon_norm_inf_triang.
- + by apply ssrnum.Num.Theory.lerD.
- Qed.
-
- Lemma canon_norm_inf_le_norm1 n (p : {poly R}) :
- (canon_norm_inf n p <= coef_norm1 p)%O.
- Proof.
- rewrite -(coefK p) poly_def coef_norm1_poly_def.
- eapply Order.POrderTheory.le_trans.
- apply canon_norm_inf_poly_def.
- rewrite /Order.le /Order.POrder.le /=.
- apply /RlebP.
- right.
- apply eq_big_seq => ??.
- by rewrite canon_norm_inf_C_pow.
- Qed.
-
- Lemma canon_norm_inf_val n (p : {poly R}) (i : 'I_(2^n-1).+1) :
- (normc ((map_poly RtoC p).[odd_nth_roots' n 0 i]) <= canon_norm_inf n p)%O.
- Proof.
- rewrite /canon_norm_inf /norm_inf /=.
- generalize (@bigmaxr_le (2^n-1).+1 (odd_nth_roots' n) 0 (fun c => normc (map_poly RtoC p).[c]) i) => HH.
- eapply Order.POrderTheory.le_trans.
- - rewrite /Order.le/=.
- apply HH.
- - rewrite /Order.le/=.
- apply /RlebP.
- right.
- apply eq_big_seq => ??.
- f_equal.
- by rewrite /mx_eval !mxE.
- Qed.
-
- Lemma peval_mx_eval {n} (v :'rV[R[i]]_n.+1) (p : {poly R}) :
- size p <= n.+1 ->
- let pmat := peval_mat v in
- let pC := map_poly RtoC p in
- mx_eval v p = (pmat *m (poly_rV pC)^T)^T.
- Proof.
- intros.
- rewrite /mx_eval /peval_mat /map_mx trmx_mul trmxK /poly_rV /= /mulmx.
- apply matrixP => j k.
- rewrite !mxE /= horner_coef -map_RtoC_size.
- unfold pmat, peval_mat, pC.
- under [RHS]eq_bigr do rewrite !mxE.
- case : (eqVneq (size p) n.+1); intros.
- - rewrite e /=.
- apply eq_big_seq => ??.
- by rewrite fintype.ord1.
- - transitivity (\sum_(i0 < (size p + (n.+1-size p))%nat)
- (map_poly (aR:=R_ringType) (rR:=ComplexField.complex_ringType R_fieldType) RtoC p)`_i0 * v 0 k ^+ i0);
- [| by have ->: (size p + (n.+1 - size p) = n.+1)%nat by lia].
- rewrite big_split_ord /=.
- rewrite !fintype.ord1.
- suff ->: ( \sum_(i0 < n.+1 - size p)
- (map_poly (aR:=R_ringType) (rR:=ComplexField.complex_ringType R_fieldType) RtoC p)`_
- (size p + i0) * v 0 k ^+ (size p + i0) = 0).
- { by rewrite addr0. }
-
- under eq_bigr => si _.
- { rewrite nth_default ?mul0r.
- - over.
- - rewrite -map_RtoC_size.
- lia.
- }
- by rewrite big_const_seq iter_addr_0 mul0rn.
- Qed.
-
-Lemma decode_encode_scalar_mx' (n : nat):
- let pmat := (peval_mat (odd_nth_roots' (S n))) in
- pmat *m (conj_mat (pmat^T)) = scalar_mx (2^S n)%:R.
-Proof.
- Proof.
- apply/matrixP.
- move/matrixP: (decode_encode_scalar_mx n) => H i j.
- have eqq: (sval (pow2_S n.+1)).+1 = (2^n.+1)%nat by (simpl; lia).
- move: (H (cast_ord eqq i) (cast_ord eqq j)).
- rewrite !mxE /= => <-.
- rewrite (big_ord_widen_leq (2^n.+1)%N); try lia.
- apply eq_big.
- - move=> [??] /=.
- by lia.
- - intros.
- by rewrite !mxE inordK.
- Qed.
-
- Lemma encode_mat_norm_inf' (n : nat) :
- let pmat' := peval_mat (odd_nth_roots' (S n)) in
- let encmat' := (conj_mat (pmat'^T)) in
- Rleb (matrix_norm_inf encmat') (2^S n)%:R.
- Proof.
- generalize (encode_mat_norm_inf n); intros.
- unfold matrix_norm_inf in *; simpl in *.
- rewrite (big_ord_widen_leq (2^n.+1)%N); try lia.
- apply /RlebP.
- move /RlebP in H.
- eapply Rle_trans; cycle 1.
- apply H.
- right.
- apply eq_big.
- - intros ?.
- destruct x.
- simpl.
- lia.
- - intros.
- rewrite (big_ord_widen_leq (2^n.+1)%N); try lia.
- apply eq_big.
- + intros ?.
- destruct x.
- simpl.
- lia.
- + intros.
- rewrite /odd_nth_roots /odd_nth_roots' !mxE !inordK //.
- Qed.
-
-
- Lemma encmat_pmat (n : nat) :
- let pmat' := peval_mat (odd_nth_roots' (S n)) in
- let encmat := (conj_mat (pmat'^T)) in
- pmat' *m ((RtoC (inv (2^S n)%:R)) *: encmat) = scalar_mx 1.
- Proof.
- intros.
- rewrite -scalemxAr /encmat /pmat' decode_encode_scalar_mx'.
- apply /matrixP => i j.
- rewrite !mxE mulrnAr -RtoCR -rmorphM /= mulrC divrr // unitfE.
- by rewrite natmul0eq pow2n0.
- Qed.
-
- Lemma invmx_comm (n : nat) (A B : 'M[R[i]]_n) :
- A *m B = scalar_mx 1 ->
- B *m A = scalar_mx 1.
- Proof.
- intros.
- destruct (mulmx1_unit H).
- generalize (mulmxV H1); intros.
- generalize H2; intros.
- apply (f_equal (fun z => A *m z)) in H2.
- rewrite mulmxA H mul1mx mulmx1 in H2.
- by rewrite H2 in H3.
- Qed.
-
- Lemma scalar_prod_comm (n : nat) (A B : 'M[R[i]]_n) (c : R[i]) :
- c != 0 ->
- A *m B = scalar_mx c ->
- B *m A = scalar_mx c.
- Proof.
- intros.
- apply (f_equal (fun z => (inv c) *: z)) in H0.
- rewrite scalemxAl scale_scalar_mx mulrC divff // in H0.
- apply invmx_comm in H0.
- rewrite -scalemxAr in H0.
- apply (f_equal (fun z => c *: z)) in H0.
- rewrite scalerA divff // scalemx1 in H0.
- by rewrite -(scale1mx (B *m A)) -H0.
- Qed.
-
- Lemma encmat_pmat_alt (n : nat) :
- let pmat' := peval_mat (odd_nth_roots' (S n)) in
- let encmat := (conj_mat (pmat'^T)) in
- ((RtoC (inv (2^S n)%:R)) *: encmat) *m pmat' = scalar_mx 1.
- Proof.
- intros.
- apply invmx_comm.
- apply encmat_pmat.
- Qed.
-
-Lemma decode_encode_off_diag_T (n : nat):
- let pmat := (peval_mat (odd_nth_roots' (S n))) in
- forall n1 n2,
- n1 <> n2 ->
- H_inner_prod (col n1 pmat)^T (col n2 pmat)^T = C0.
-Proof.
- intros.
- rewrite !tr_col -H_inner_prod_mat trmxK.
- generalize (encmat_pmat_alt n); intros.
- simpl in H0.
- apply (f_equal (fun m => trmx ((RtoC (2 ^ n.+1)%:R) *: m))) in H0.
- rewrite scalemxAl scalerA in H0.
- replace ((RtoC (2 ^ n.+1)%:R * RtoC (2 ^ n.+1)%:R^-1)) with (RtoC 1) in H0.
- - rewrite trmx_mul scale1r conj_transpose trmxK scalemx1 in H0.
- rewrite /pmat H0 tr_scalar_mx mxE.
- case: (eqVneq n1 n2); intros.
- + by rewrite e in H.
- + by rewrite /= mulr0n.
- - rewrite -rmorphM /= divff // natmul0eq.
- lia.
-Qed.
-
- Lemma encmat_pmat_pvec (n : nat) (p : {poly R}) :
- let pmat' := peval_mat (odd_nth_roots' (S n)) in
- let encmat := (conj_mat (pmat'^T)) in
- let pvec := (poly_rV (d := (sval (pow2_S n.+1)).+1)
- (map_poly (aR:=R_ringType)
- (rR:=ComplexField.complex_ringType
- R_fieldType) RtoC p))^T in
- (((RtoC (inv (2^S n)%:R)) *: encmat) *m pmat') *m pvec = pvec.
- Proof.
- by rewrite /= encmat_pmat_alt mul1mx.
- Qed.
-
- Lemma big_max_split {k2 : nat} (k1 : nat) (F : 'I_k2 -> R) :
- \big[Order.max/0]_(j < k2) F j =
- Order.max
- (\big[Order.max/0]_(j < k2 | true && (j < k1)%N) F j)
- (\big[Order.max/0]_(j < k2 | true && ~~(j < k1)) F j).
- Proof.
- rewrite -Order.TotalTheory.bigmaxID.
- by apply eq_bigr.
- Qed.
-
-Lemma big_max_nneg_with_trailing_zeros {k1 k2} (le12: k1 <= k2) (F: 'I_k2 -> R) :
- (forall i, Rle 0 (F i)) ->
- (forall i: 'I_k2 , k1 <= i -> F i = 0%R) ->
- \big[Order.max/0]_(j < k2) F j = \big[Order.max/0]_(j < k1) F (widen_ord le12 j).
- Proof.
- intros Fnneg Ftrail0.
- rewrite (big_max_split k1).
- assert (\big[Order.max/0]_(j < k2 | true && ~~ (j < k1)) F j = 0).
- {
- under eq_bigr.
- intros.
- rewrite Ftrail0; try lia.
- over.
- rewrite big_const_seq iter_fix // -RmaxE /zero/= Rmax_left //; coq_lra.
- }
- rewrite H -RmaxE Rmax_left.
- destruct k1.
- - rewrite big_ord0.
- pose G : ('I_k2 -> R_orderType) := fun=> 0%R.
- assert (\big[Order.max/0]_(j < k2 | true && (j < 0)) F j =
- \big[Order.max/0]_(j < k2 | true && (j < 0)) G j).
- {
- apply congr_big; trivial.
- intros ??.
- lia.
- }
- rewrite H0 /G.
- rewrite big_const_seq iter_fix // -RmaxE /zero/= Rmax_left //; coq_lra.
- - rewrite [RHS](big_ord_widen_leq k2); trivial.
- apply eq_bigr.
- intros.
- f_equal.
- apply ord_inj => /=.
- by rewrite inordK.
- - pose G : ('I_k2 -> R_orderType) := fun=> 0%R.
- assert ((\big[Order.max/0]_(j < k2 | true && (j < k1)%N) (G j)) = 0)%R.
- {
- rewrite /G big_const_seq iter_fix // -RmaxE /zero/= Rmax_left //; coq_lra.
- }
- rewrite -{1}H0.
- apply /RleP.
- apply (@Order.TotalTheory.le_bigmax2 _ R_orderType).
- intros.
- rewrite /G.
- apply /RleP.
- apply Fnneg.
- Qed.
-
-
-
-
-
- Lemma coef_maxnorm_pvec n (p : {poly R}) :
- size p <= 2^n.+1 ->
- let pvec := (poly_rV (d := (sval (pow2_S n.+1)).+1)
- (map_poly (aR:=R_ringType)
- (rR:=ComplexField.complex_ringType
- R_fieldType) RtoC p))^T in
- coef_maxnorm p = cvec_norm_inf pvec.
- Proof.
- intros.
- rewrite /coef_maxnorm /cvec_norm_inf /norm_inf.
- assert ((sval (pow2_S n.+1)).+1 = (2^n.+1)%N).
- {
- simpl.
- lia.
- }
- case : (eqVneq (size p) (sval (pow2_S n.+1)).+1); intros.
- - rewrite e /=.
- apply eq_big_seq => k HH.
- rewrite /pvec !mxE.
- rewrite /map_poly coef_poly /=.
- case: ltP; rewrite normc_Rabs // => kbig.
- rewrite nth_default //.
- lia.
- - have le1: (size p <= (sval (pow2_S n.+1)).+1) by lia.
- rewrite (big_max_nneg_with_trailing_zeros le1).
- + apply eq_bigr => j _.
- rewrite /pvec /poly_rV !mxE.
- rewrite /= /map_poly coef_poly.
- destruct j.
- simpl in i0.
- by rewrite /= i0 normc_Rabs.
- + intros.
- apply/RleP.
- apply normc_nneg.
- + intros.
- rewrite /pvec /= /poly_rV.
- rewrite !mxE /= map_polyE.
- rewrite nth_default.
- * by rewrite ComplexField.Normc.normc0.
- * eapply leq_trans.
- -- apply size_Poly.
- -- by rewrite size_map.
- Qed.
-
- Lemma canon_norm_inf_pvec n (p : {poly R}) :
- size p <= 2^n ->
- let pmat' := peval_mat (odd_nth_roots' n) in
- let pvec := (poly_rV (d := (sval (pow2_S n)).+1)
- (map_poly (aR:=R_ringType)
- (rR:=ComplexField.complex_ringType
- R_fieldType) RtoC p))^T in
- canon_norm_inf n p = cvec_norm_inf (pmat' *m pvec).
- Proof.
- intros.
- rewrite /canon_norm_inf /cvec_norm_inf /norm_inf.
- apply eq_big; trivial; intros.
- f_equal.
- rewrite /odd_nth_roots' /pmat' /pvec /peval_mat !mxE /map_poly horner_poly.
- rewrite /odd_nth_roots'.
- under [RHS]eq_bigr do rewrite !mxE mulrC coef_poly.
- simpl.
- case : (eqVneq (size p) (2^n-1).+1); intros.
- - rewrite e.
- apply eq_big_seq => k HH.
- f_equal.
- assert (k < (2 ^ n - 1).+1).
- {
- by destruct k; simpl.
- }
- by rewrite H1.
- - transitivity (\sum_(i1 < size p + ((2^n-1).+1-size p)%nat)
- (if i1 < size p then RtoC p`_i1 else 0) *
- nth_root (2 * i + 1) (2 ^ n.+1) ^+ i1);
- [| by have ->: (size p + ((2^n-1).+1 - size p) = (2^n-1).+1)%nat by lia].
- rewrite big_split_ord /=.
- assert (\sum_(i1 < (2 ^ n - 1).+1 - size p)
- (if size p + i1 < size p then RtoC p`_(size p + i1) else 0) *
- nth_root (2 * i + 1) (2 ^ n.+1) ^+ (size p + i1) = 0).
- {
- under eq_bigr => si _.
- {
- have ->: ((size p + si < size p) = false) by lia.
- rewrite mul0r.
- over.
- }
- by rewrite big_const_seq iter_addr_0 mul0rn.
- }
- rewrite H1 addr0.
- by apply eq_big => // [[? /= ->]].
- Qed.
-
- Lemma matrix_norm_inf_pmat_inv n :
- let pmat' := peval_mat (odd_nth_roots' (S n)) in
- let encmat := (conj_mat (pmat'^T)) in
- Rle (matrix_norm_inf ((RtoC (inv (2^S n)%:R)) *: encmat)) 1.
- Proof.
- generalize (encode_mat_norm_inf' n); intros.
- simpl in H.
- move /RlebP in H.
- rewrite matrix_norm_inf_scale normc_Rabs.
- apply Rmult_le_compat_l with (r := Rabs (2 ^ n.+1)%:R^-1) in H.
- - eapply Rle_trans.
- apply H.
- right.
- assert (Rlt 0 (2 ^ n.+1)%:R).
- {
- rewrite -INRE.
- apply lt_0_INR.
- lia.
- }
- rewrite Rabs_right.
- + rewrite /inv/= RmultRinvx // unitrE divff //.
- rewrite /zero/=.
- apply/eqP.
- apply Rgt_not_eq.
- apply H0.
- + left.
- rewrite -RinvE.
- * apply Rinv_0_lt_compat; trivial.
- * rewrite /zero/=.
- apply/eqP.
- apply Rgt_not_eq.
- apply H0.
- - apply Rabs_pos.
- Qed.
-
- Lemma Rmult_le_1 (r1 r2 : R) :
- Rle r1 1 ->
- Rle 0 r2 ->
- Rle (r1 * r2) r2.
- Proof.
- intros.
- apply Rmult_le_compat_r with (r := r2) in H; trivial.
- by rewrite Rmult_1_l in H.
- Qed.
-
- Lemma matrix_norm_inf_pos {n m} (A : 'M[R[i]]_(n,m)) :
- Rle 0 (matrix_norm_inf A).
- Proof.
- rewrite /matrix_norm_inf.
- apply /RlebP.
- apply bigmax_nneg.
- intros.
- apply sum_nneg.
- intros.
- apply /RlebP.
- apply normc_nnegR.
- Qed.
-
- Theorem coef_maxnorm_le_canon_norm_inf n (p : {poly R}) :
- size p <= 2^n.+1 ->
- Rle (coef_maxnorm p) (canon_norm_inf n.+1 p).
- Proof.
- intros.
- rewrite (coef_maxnorm_pvec n) //.
- rewrite canon_norm_inf_pvec //.
- rewrite -{1}encmat_pmat_pvec /cvec_norm_inf.
- rewrite !mat_vec_norm_inf !trmxK.
- eapply Rle_trans.
- - rewrite -mulmxA.
- apply /RlebP.
- apply matrix_norm_inf_sub_mult.
- - generalize (matrix_norm_inf_pmat_inv n); intros.
- simpl in H0.
- apply Rmult_le_1; trivial.
- apply matrix_norm_inf_pos.
- Qed.
-
- Theorem canon_norm_inf_bounds n (p : {poly R}) :
- size p <= 2^n.+1 ->
- (coef_maxnorm p <= canon_norm_inf n.+1 p <= coef_norm1 p)%O.
- Proof.
- intros.
- apply /andP.
- split.
- - apply /RleP.
- by apply coef_maxnorm_le_canon_norm_inf.
- - apply canon_norm_inf_le_norm1.
- Qed.
-
- Lemma canon_norm_zero_mod_qpoly n (p : {poly R}) :
- canon_norm_inf n p = 0 ->
- Pdiv.Ring.rmodp (R:=R_ringType) p ('X^(2 ^ n) + 1%:P) = 0.
- Proof.
- intros.
- apply odd_nth_roots_minpoly_mult_R.
- intros.
- unfold root.
- apply /eqP.
- generalize (canon_norm_inf_val n p); intros.
- destruct i.
- assert (m < (2^n-1).+1) by lia.
- specialize (H0 (Ordinal H1)).
- rewrite H in H0.
- apply ComplexField.Normc.eq0_normc.
- apply Order.POrderTheory.le_anti.
- apply /andP; split.
- - replace (odd_nth_roots n 0 (Ordinal (n:=2^n) i)) with
- (odd_nth_roots' n 0 (Ordinal (n:=(2^n-1).+1) H1)).
- + apply H0.
- + by rewrite /odd_nth_roots' /odd_nth_roots !mxE.
- - apply normc_nneg.
- Qed.
-
-(* following only holds on quotient ring by x^+(2^n) + 1
- Lemma canon_norm_inf_pos_def n p :
- canon_norm_inf n p = 0 -> p = 0.
-*)
-
-
- Lemma normc_conj_mul (x y : R[i]) :
- normc (x * y) = normc (x * (conjc y)).
- Proof.
- by rewrite !ComplexField.Normc.normcM (normc_conj y).
- Qed.
-
- Lemma normc_conj_add (r : R) (x y : R[i]) :
- normc (x + y) = normc (conjc x + conjc y).
- Proof.
- by rewrite -rmorphD normc_conj.
- Qed.
-
- Lemma normc_conj_exp (x : R[i]) n :
- normc (x ^+ n) = normc ((conjc x) ^+ n).
- Proof.
- by rewrite -rmorphXn normc_conj.
- Qed.
-
- Lemma RtoC1 : RtoC 1 = 1.
- Proof.
- by [].
- Qed.
-
- Lemma RtoC0E (c:R) : (RtoC c == 0) = (c == 0).
- Proof.
- by rewrite /RtoC !eqE /= !eqE /= eqxx !andbT.
- Qed.
-
- Lemma RtoC_real a : RtoC a \is ssrnum.Num.real.
- Proof.
- by rewrite complex_real.
- Qed.
-
- Lemma conjc_id (a:R[i]) : (a^* = a)%C <-> (a \is ssrnum.Num.real).
- Proof.
- split.
- - move/Cconj_im_0 => eqq.
- apply/ssrnum.Num.Theory.Creal_ImP.
- by rewrite -complexIm /= eqq.
- - rewrite /in_mem/mem /= /Order.le /=.
- case: a => ra ia /= /orP-[/andP-[/eqP-> _]|/andP-[/eqP<- _]]
- ; by rewrite oppr0.
- Qed.
-
- Lemma conjc_RtoC a : ((RtoC a)^* )%C = RtoC a.
- Proof.
- by rewrite /RtoC /= oppr0.
- Qed.
-
- Lemma rpoly_eval_conj (p : {poly R}) (x : R[i]) :
- let pc := map_poly RtoC p in
- pc.[x]^*%C = pc.[x^*%C].
- Proof.
- case: p => l llast.
- rewrite /horner /= !map_polyE /=.
- have/PolyK->: seq.last (RtoC 1) (seq.map RtoC l) != 0
- by rewrite seq.last_map RtoC0E.
- move => {llast}.
- elim: l => /=.
- - by rewrite oppr0.
- - move=> a l <-.
- by rewrite rmorphD rmorphM /RtoC /= oppr0.
- Qed.
-
- Lemma rpoly_eval_conj_R (p : {poly R}) (x : R[i]) :
- let pc := map_poly RtoC p in
- pc.[x] = pc.[x]^*%C ->
- pc.[x] = pc.[x^* %C].
- Proof.
- move=> /= -> . by rewrite rpoly_eval_conj.
- Qed.
-
- Lemma normc_conj_poly (p : {poly R}) (x : R[i]) :
- let pc := map_poly RtoC p in
- normc (pc.[x]) = normc (pc.[x^*%C]).
- Proof.
- by rewrite /= -rpoly_eval_conj normc_conj.
- Qed.
-
-End norms.
-
-Lemma pmat_normc_1 (n : nat) :
- let pmat := peval_mat (odd_nth_roots (S n)) in
- forall i j,
- Cmod (pmat i j) = 1.
-Proof.
- simpl; intros.
- unfold peval_mat, odd_nth_roots.
- rewrite !mxE.
- destruct (pow2_S (n.+2)).
- move => /eqP in i0.
- by rewrite i0 exp_nth_root nth_root_Cmod.
-Qed.
-
-Definition coefE :=
- (coef0, coef1, coefC, coefX, coefXn, coef_sumMXn,
- coefZ, coefMC, coefCM, coefXnM, coefMXn, coefXM, coefMX, coefMNn, coefMn,
- coefN, coefB, coefD, coef_even_poly, coef_odd_poly,
- coef_take_poly, coef_drop_poly, coef_cons, coef_Poly, coef_poly,
- coef_deriv, coef_nderivn, coef_derivn, coef_map, coef_sum,
- coef_comp_poly_Xn, coef_comp_poly).
-
-Lemma Cmod_normc (c : R[i]) :
- Cmod c = ComplexField.Normc.normc c.
-Proof.
- unfold Cmod, ComplexField.Normc.normc.
- destruct c.
- rewrite RsqrtE//.
- apply ssrnum.Num.Theory.addr_ge0; apply ssrnum.Num.Theory.sqr_ge0.
-Qed.
-
-Lemma Cmod_triang (x y : R[i]) :
- Rle (Cmod (x + y)) (Cmod x + Cmod y).
-Proof.
- rewrite !Cmod_normc.
- by apply normc_triangR.
-Qed.
-
-Lemma Cmod_mul (c1 c2 : R[i]):
- Cmod (c1 * c2) = (Cmod c1) * (Cmod c2).
-Proof.
- by rewrite !Cmod_normc ComplexField.Normc.normcM.
-Qed.
-
-Lemma cmod_1_delta (c1 c2 : R[i]) (delta : R) :
- Cmod c1 = 1 ->
- Rlt (Cmod c2) delta ->
- Rlt (Cmod (c1 * c2)) delta.
-Proof.
- intros.
- by rewrite Cmod_mul H mul1r.
-Qed.
-
-Lemma root_eval_bound_cpoly (c : R[i]) (p : {poly R[i]}) (δ : R) :
- Rle 0 δ ->
- Cmod c = 1 ->
- (forall i, Rle (Cmod (p`_ i)) δ) ->
- Rle (Cmod (p.[c])) (δ *+ (seq.size p)).
-Proof.
- intros δnneg Cnorm1 coeffsmall.
- rewrite -{1}[p]polyseqK horner_Poly.
- move: (polyseq p) coeffsmall => {p}.
- elim.
- - move=> _/=.
- rewrite mulr0n expr0n /= !mulr0n Rplus_0_l sqrt_0.
- by apply Rle_refl.
- - move=> a l IHl coeffsmall /=.
- rewrite mulrS [δ + _]addrC.
- rewrite !Cmod_normc.
- eapply Rle_trans; [apply normc_triangR | ].
- rewrite ComplexField.Normc.normcM.
- rewrite -!Cmod_normc Cnorm1 mulr1.
- rewrite /add /=.
- apply Rplus_le_compat.
- + apply IHl.
- move=>i.
- by move: coeffsmall => /(_ (i.+1))/=.
- + by move: coeffsmall => /(_ O) /= //.
-Qed.
-
-Lemma RtoC_real_complex (r : R) :
- RtoC r = real_complex _ r.
-Proof.
- reflexivity.
-Qed.
-
-Lemma Cmod_Rabs (r : R) :
- Cmod (RtoC r) = Rabs r.
-Proof.
- rewrite RtoC_real_complex /Cmod /=.
- rewrite (expr2 0) mulr0 Rplus_0_r.
- generalize (pow2_inv r (r^+ 2)); intros.
- by rewrite H.
-Qed.
-
-Lemma Cmod_0 :
- Cmod 0 = 0.
-Proof.
- unfold Cmod. simpl.
- by rewrite !expr2 !mulr0 Rplus_0_r sqrt_0.
-Qed.
-
-Lemma root_eval_bound (c : R[i]) (p : {poly R}) (δ : R) :
- Rle 0 δ ->
- Cmod c = 1 ->
- (forall i, Rle (Rabs (p`_ i)) δ) ->
- Rle (Cmod (map_poly RtoC p).[c]) (δ *+ seq.size p).
-Proof.
- intros.
- generalize (root_eval_bound_cpoly c (map_poly RtoC p) δ H H0); intros.
- rewrite (size_map_poly RtoC_rmorphism p) in H2.
- apply H2; intros.
- rewrite coefE.
- destruct (i < seq.size p)%N.
- - by rewrite Cmod_Rabs.
- - by rewrite Cmod_0.
-Qed.
-
-Lemma Cmod_sum (n : nat) (cl : 'I_n -> R[i]) :
- Rle (Cmod (\sum_i cl i)) (\sum_i (Cmod (cl i))).
-Proof.
- simpl.
- elim: (index_enum (ordinal_finType n)) => /=.
- - rewrite !big_nil /=.
- rewrite !expr2 !mulr0 Rplus_0_r sqrt_0.
- by apply Rle_refl.
- - move=> a l IHl.
- rewrite !big_cons /=.
- eapply Rle_trans; [ apply Cmod_triang |].
- apply Rplus_le_compat; trivial.
- by apply Rle_refl.
-Qed.
-
-Lemma const_sum (n : nat) (δ : R) :
- \sum_(i < n) δ = n%:R * δ.
-Proof.
- rewrite big_const_ord.
- induction n.
- + by rewrite /= mul0r.
- + by rewrite /= IHn mulrS mulrDl mul1r.
-Qed.
-
-Lemma leq_sum_R [I : Type] (r : list I) [P : pred I] [E1 E2 : I -> R] :
- (forall i : I, P i -> Rle (E1 i) (E2 i)) ->
- Rle (\sum_(i <- r | P i) E1 i) (\sum_(i <- r | P i) E2 i).
-Proof.
- apply big_ind2.
- - by right.
- - by apply Rplus_le_compat.
-Qed.
-
-Lemma bounded_sum (n : nat) (cl : 'I_n -> R) (δ : R) :
- Rle 0 δ ->
- (forall i, Rle (cl i) δ) ->
- Rle (\sum_i cl i) (n%:R * δ).
-Proof.
- intros.
- apply Rle_trans with (r2 := \sum_(i < n) δ).
- - by apply leq_sum_R.
- - right.
- by apply const_sum.
- Qed.
-
-Lemma Cabs_sum_bound (n : nat) (cl : 'I_n -> R[i]) (δ : R) :
- Rle 0 δ ->
- (forall i, Rle (Cmod (cl i)) δ) ->
- Rle (Cmod (\sum_i cl i)) (n%:R * δ).
-Proof.
- intros.
- eapply Rle_trans.
- apply Cmod_sum.
- by apply bounded_sum.
-Qed.
-
-Lemma decode_delta (n : nat) (cl : 'cV[R[i]]_(2^(S n))) (δ : R) :
- Rle 0 δ ->
- let pmat := peval_mat (odd_nth_roots (S n)) in
- (forall i, Rle (Cmod (cl i 0)) δ) ->
- forall i, Rle (Cmod ((pmat *m cl) i 0)) ((2^S n)%:R * δ).
-Proof.
- simpl; intros.
- rewrite !mxE.
- apply Cabs_sum_bound; trivial.
- intros.
- by rewrite Cmod_mul pmat_normc_1 mul1r.
- Qed.
-
- Lemma mul_dvdn_l (x d1 d2:nat) :
- (d1 * d2 %| x)%N -> (d1 %| x)%N.
- Proof.
- eapply dvdn_trans.
- apply dvdn_mulr.
- by apply dvdnn.
- Qed.
-
- Lemma mul_dvdn_r (x d1 d2:nat) :
- (d1 * d2 %| x)%N -> (d2 %| x)%N.
- Proof.
- rewrite mulnC.
- by apply mul_dvdn_l.
- Qed.
-
- Lemma modn_muln (x y b1 b2:nat) :
- x == y %[mod b1 * b2] -> x == y %[mod b1].
- Proof.
- wlog le_yx : x y / y <= x; last by (rewrite !eqn_mod_dvd //; apply mul_dvdn_l).
- by have [?|/ltnW ?] := leqP y x; last rewrite !(eq_sym (x %% _)%N); apply.
- Qed.
-
- Section unity.
- Context {T : comRingType}
- (z : T).
-
- Lemma two_pow_prim_root_alt (n:nat) :
- z ^+ (2^n) <> 1 ->
- z ^+ (2^n.+1) = 1 ->
- primitive_root_of_unity (2^(n.+1)) z.
- Proof.
- intros zpow_n1 zpow1.
- assert (root_of_unity (2^(n.+1)) z).
- {
- by apply /unity_rootP.
- }
- destruct (@prim_order_exists _ (2^(n.+1)) z).
- - destruct (pow2_S (n.+1)).
- move=> /eqP in i.
- by rewrite i.
- - by apply /unity_rootP.
- - assert (exists (k:nat), x = expn 2 k).
- {
- move=> /prime.dvdn_pfactor in i0.
- destruct i0.
- by reflexivity.
- by exists x0.
- }
- move: H0 i0 i => [x0 ->] i0 i.
- have HH: x0 = n.+1.
- {
- move: i0.
- rewrite div.dvdn_Pexp2l; try lia.
- rewrite leq_eqVlt => /orP-[/eqP//|x0lt].
- assert (HH:expn 2 n = muln (expn 2 x0) (expn 2 (n - x0))).
- {
- rewrite -expnD.
- f_equal.
- lia.
- }
- by rewrite HH exprM (prim_expr_order i) Theory.expr1n in zpow_n1.
- }
- by rewrite HH in i.
- Qed.
-
- Lemma two_pow_prim_root (n:nat) :
- -(one T) <> (one T) ->
- z ^+ (2^n) = -1 ->
- primitive_root_of_unity (2^(n.+1)) z.
- Proof.
- intros onem1 zpowm1.
- apply two_pow_prim_root_alt.
- - by rewrite -zpowm1 in onem1.
- - by rewrite expnSr exprM zpowm1 expr2 mulrNN mulr1.
- Qed.
-
- Lemma unity_gcd e1 e2 :
- z^+e1 = 1 ->
- z^+e2 = 1 ->
- z^+(gcdn e1 e2) = 1.
- Proof.
- intros.
- destruct e2.
- - by rewrite gcdn0.
- - assert (0 < e2.+1) by lia.
- destruct (egcdnP e1 H1).
- apply (f_equal (fun y => y^+kn)) in H.
- rewrite -exprM mulnC expr1n in H.
- apply (f_equal (fun z => z^+km)) in H0.
- by rewrite -exprM mulnC e exprD expr1n H mul1r gcdnC in H0.
- Qed.
-
-Lemma modn_sub i j m :
- i >= j ->
- (i == j %[mod m]) = (i - j == 0 %[mod m]).
-Proof.
- move/eqn_mod_dvd->.
- by rewrite mod0n.
-Qed.
-
-Lemma modn_sub_iff i j m :
- i >= j ->
- i = j %[mod m] <-> i - j = 0 %[mod m].
-Proof.
- move/modn_sub=>eqq.
- split; move/eqP
- ; [rewrite eqq | rewrite -eqq]
- ; by move/eqP.
-Qed.
-
-Lemma prim_root_inv (k n : nat) :
- primitive_root_of_unity (n.+1) z ->
- k < n.+1 ->
- (z^+k) * (z ^+ (n.+1 - k)) = 1.
-Proof.
- intros.
- rewrite -exprD.
- replace (k + (n.+1-k))%N with (n.+1)%N by lia.
- by apply prim_expr_order.
-Qed.
-
-Lemma prim_root_inv' (k n : nat) :
- n != 0%N ->
- primitive_root_of_unity n z ->
- k < n ->
- (z^+k) * (z ^+ (n - k)) = 1.
-Proof.
- intros.
- rewrite -exprD.
- replace (k + (n-k))%N with (n)%N by lia.
- by apply prim_expr_order.
-Qed.
-
-Lemma prim_root_pow_unique (k1 k2 n : nat) :
- primitive_root_of_unity n z ->
- z ^+ k1 = z^+ k2 <-> k1 = k2 %[mod n].
-Proof.
- intros.
- generalize (eq_prim_root_expr H k1 k2); intros.
- split; intros.
- - move /eqP in H1.
- rewrite H0 in H1.
- apply (eqP H1).
- - move /eqP in H1.
- rewrite -H0 in H1.
- apply (eqP H1).
-Qed.
-
-Lemma two_pow_prim_root_inv (k n : nat) :
- primitive_root_of_unity (2^n.+1) z ->
- k < 2^n.+1 ->
- (z^+k) * (z ^+ (2^n.+1 - k)) = 1.
-Proof.
- intros.
- apply prim_root_inv'; lia.
-Qed.
-
-Lemma prim_root_pow_sqr (k n : nat) :
- n != 0%N ->
- primitive_root_of_unity (2*n)%N z ->
- (z^+k)^+2 = 1 ->
- k = 0 %[mod n].
-Proof.
- intros.
- rewrite -exprM mulnC in H1.
- generalize (prim_root_pow_unique (2*k) 0%N (2*n)%N H0); intros.
- rewrite expr0 in H2.
- assert (0 < 2*n) by lia.
- rewrite H2 -muln_modr (modn_small H3) in H1.
- assert (k %% n = 0)%N by lia.
- by rewrite H4 mod0n.
-Qed.
-
-Lemma zero_modn_mod2n (k n : nat) :
- 0 < n ->
- k = 0 %[mod n] ->
- k = 0 %[mod 2*n] \/ k = n %[mod 2*n].
-Proof.
- move=> npos /eqP.
- move: (dvdn_eq n k).
- rewrite (modn_small npos) /dvdn => -> eqq1.
-
- have ->: (k = k%/n * n)%N.
- {
- symmetry.
- by apply /eqP.
- }
- rewrite -muln_modl.
- have [-> | eqq]: (((k %/n) %% 2 = 0) \/ ((k %/n) %% 2 = 1))%N.
- {
- have: ((k %/ n) %% 2)%N < 2 by by apply ltn_pmod.
- lia.
- }
- - left; by rewrite mod0n mul0n.
- - right.
- rewrite -{3}(mul1n n) eqq -muln_modl.
- lia.
-Qed.
-
-Lemma two_pow_prim_root_m1 (k n : nat) :
- primitive_root_of_unity (2^n.+1) z ->
- -(one T) <> (one T) ->
- z^+k = -1 ->
- k = 2^n %[mod 2^n.+1].
- Proof.
- intros.
- assert (2^n != 0)%N by lia.
- rewrite expnS in H.
- assert (z ^+ k ^+ 2 = 1).
- {
- by rewrite H1 expr2 mulrNN mulr1.
- }
- generalize (prim_root_pow_sqr k (2^n) H2 H H3); intros.
- assert (k <> 0 %[mod 2^n.+1]).
- {
- unfold not; intros.
- rewrite -expnS in H.
- generalize (prim_root_pow_unique k 0 (2^n.+1) H); intros.
- rewrite expr0 H1 in H6.
- by rewrite -H6 in H5.
- }
- clear H H0 H1 H3 z T.
- assert (0 < 2^n)%N by lia.
- generalize (zero_modn_mod2n k (2^n) H H4); intros.
- rewrite expnS.
- rewrite expnS in H5.
- by destruct H0.
- Qed.
-
-Lemma two_pow_prim_root_m1_alt (n : nat) :
- primitive_root_of_unity (2^n.+1) z ->
- -(one T) <> (one T) ->
- z^+(2^n) <> -1 ->
- not (exists k, z^+k = -1).
-Proof.
- intros.
- unfold not; intros.
- destruct H2.
- generalize (two_pow_prim_root_m1 x n H H0 H2); intros.
- generalize (prim_expr_mod H); intros.
- by rewrite -H4 H3 H4 in H2.
-Qed.
-
- Lemma odd_pow_prim_root (n:nat) :
- z ^+ (2^n) = -1 ->
- forall j,
- odd j ->
- ((z ^+ j) ^+ (2^n)) = -1.
- Proof.
- intros.
- by rewrite exprAC H -signr_odd H0 /= expr1.
- Qed.
-
- Lemma gcd_odd_pow2 j n :
- odd j ->
- (div.gcdn j (2 ^ (S n)) = 1)%N.
- Proof.
- intros.
- generalize (div.coprime2n j); intros.
- assert (div.gcdn j 2 = 1%N).
- {
- rewrite H in H0.
- unfold div.coprime in H0.
- move => /eqP in H0.
- by rewrite div.gcdnC.
- }
- induction n; trivial.
- rewrite expnS.
- assert (div.coprime j 2).
- {
- unfold div.coprime.
- by apply /eqP.
- }
- now rewrite (@div.Gauss_gcdr j 2).
- Qed.
-
- Lemma pow2_odd_inv (j n :nat) :
- odd j ->
- {k | (j * k mod (2^(S n)) = 1)%N}.
- Proof.
- intros.
- assert (0 < j) by lia.
- destruct (div.egcdnP (2^(S n)) H0).
- exists km.
- rewrite mulnC e addnC Nat.mod_add; try lia.
- rewrite (gcd_odd_pow2 j n H).
- apply Nat.mod_small.
- destruct (pow2_S n).
- move => /eqP in i0.
- rewrite expnS i0; lia.
- Qed.
-
- Lemma odd_pow_prim_root_inv (j n:nat) :
- odd j ->
- exists k,
- forall (z2:T),
- z2 ^+ (2^n) = -1 ->
- ((z2 ^+ j) ^+ k) = z2.
- Proof.
- intros.
- assert (2^(S n) <> 0)%N.
- {
- destruct (pow2_S (S n)).
- move => /eqP in i.
- rewrite i.
- lia.
- }
- generalize (pow2_odd_inv j n H); intros.
- destruct H1 as [k ?].
- exists k.
- intros.
- assert (z2 ^+ (2 ^ (S n)) = 1).
- {
- by rewrite expnS mulnC exprM H1 expr2 mulrNN mulr1.
- }
- rewrite -exprM.
- rewrite (Nat.div_mod (j * k) _ H0) e.
- by rewrite exprD exprM H2 expr1 Theory.expr1n mul1r.
- Qed.
-
- Lemma odd_pow_prim_inv (n:nat) :
- z ^+ (2^n) = -1 ->
- forall j,
- ((z ^+ j) ^+ (2^(S n) -1)) * (z ^+ j) = 1.
- Proof.
- intros.
- rewrite -exprM -exprD /= -{2}(muln1 j) -mulnDr mulnC exprM.
- rewrite addBnCAC; try lia.
- by rewrite subnn add0n expnS mulnC exprM H expr2 mulrNN mulr1 expr1n.
- Qed.
-
- Lemma odd_pow_prim_root_inj (j n:nat) (z2 : T) :
- z ^+ (2^n) = -1 ->
- z2 ^+ (2^n) = -1 ->
- z <> z2 ->
- odd j ->
- z ^+ j <> z2 ^+ j.
- Proof.
- intros.
- unfold not; intros.
- destruct (odd_pow_prim_root_inv j n H2) as [k ?].
- apply H4 in H.
- apply H4 in H0.
- apply (f_equal (fun z => z ^+k)) in H3.
- by rewrite H H0 in H3.
- Qed.
-
- Lemma nth_root_eq' j k (n:nat) : n != 0%nat ->
- j mod n = k mod n <->
- nth_root j n = nth_root k n.
- Proof.
- destruct n; [lia |]=>_.
- apply nth_root_eq.
- Qed.
-
- Lemma expn_n0 (c:nat) n : c != 0%nat -> expn c n != 0%nat.
- Proof.
- lia.
- Qed.
-
-
- Lemma iff_eqb (a b:bool) : (a <-> b) <-> a = b.
- Proof.
- destruct a; destruct b; simpl in *; firstorder.
- elim H => //.
- Qed.
-
- Lemma odd_rep (n : nat) :
- odd n -> n = (2*(n%/2)+1)%N.
- Proof.
- intros.
- generalize (divn_eq n 2); intros.
- assert (n %% 2 = 1)%N.
- {
- by rewrite modn2 H.
- }
- by rewrite H1 in H0; lia.
- Qed.
-
-
- Lemma modn2_odd (x : nat) :
- odd x <-> (x %% 2 = 1)%N.
- Proof.
- rewrite modn2.
- by case: (odd _).
- Qed.
-
- Lemma pow2_odd_rem1_odd (x n : nat) :
- (x %% 2^(n.+1) = 1)%N -> odd x.
- Proof.
- intros.
- rewrite expnS in H.
- have: (x %% 2 = 1)%N.
- {
- assert (1 %% (2 * 2^n) = 1)%N.
- {
- rewrite modn_small; lia.
- }
- rewrite -{3}H0 in H.
- move /eqP in H.
- apply modn_muln in H.
- move /eqP in H.
- rewrite H.
- rewrite modn_small; trivial.
- }
- by rewrite modn2_odd.
- Qed.
-
- Lemma pow2_odd_inv_aux (j x n : nat) :
- ((x * (2*j+1)) %% 2^(n.+1) = 1)%N ->
- exists x0, x = (2*x0+1)%N.
- Proof.
- intros.
- exists (x%/2)%N.
- apply odd_rep.
- apply pow2_odd_rem1_odd in H.
- rewrite oddM in H.
- move /andP in H.
- tauto.
- Qed.
-
- Lemma odd_pow_prim_roots_perm_eq (j n : nat) :
- let l := mkseq (fun i => nth_root (2 * i + 1) (2 ^ (S n))) (2^n) in
- perm_eq l (map (fun r => r^+(2*j+1)) l).
- Proof.
- assert (uniq (mkseq (fun i : nat => nth_root (2 * i + 1) (2 ^ n.+1)) (2 ^ n))).
- {
- apply /mkseq_uniqP.
- intros ?????.
- rewrite /in_mem /mem /= in H.
- rewrite /in_mem /mem /= in H0.
- destruct (pow2_S (S n)).
- rewrite (eqP i) -nth_root_eq -(eqP i) in H1.
- rewrite !Nat.mod_small in H1; try rewrite expnS; try lia.
- }
- assert (odd (2*j+1)) by lia.
- destruct (pow2_odd_inv (2*j+1) n H0).
- rewrite mulnC modulo_modn in e.
- apply uniq_perm; trivial.
- - rewrite map_inj_in_uniq // => a b.
- rewrite /mkseq => /mapP-[i inth ->] /mapP-[k knth ->].
- do 2 rewrite pow_nth_root' ?expn_n0 //.
- do 2 rewrite -nth_root_eq' ?expn_n0 //.
- rewrite !modulo_modn => HH.
- apply (f_equal (fun k => ((x %% 2^n.+1) * k) %% 2^n.+1)%N) in HH.
- do 2 rewrite modnMm mulnA -modnMm e mul1n in HH.
- by rewrite !modn_mod in HH.
- - move=> a.
- apply iff_eqb.
- split; rewrite /mkseq -map_comp /comp => /mapP-[i inth ->]; apply/mapP.
- + assert (exists x0, (2 * x0 +1)*(2 * j +1) mod 2^n.+1 = (2 * i +1)%N mod 2^n.+1).
- {
- destruct (pow2_odd_inv_aux _ _ _ e).
- rewrite H1 in e.
- exists ((2*(x0 * i)+x0+i)%N).
- rewrite !modulo_modn.
- replace (2 * (2 * (x0 * i) + x0 + i) + 1)%N with ((2 * x0 + 1)*(2 * i + 1))%N by lia.
- rewrite -mulnA (mulnC (2 * i + 1)%N _) mulnA.
- rewrite -modnMm e mul1n.
- by rewrite modn_mod.
- }
- destruct H1.
- exists (x0 mod 2^n).
- * rewrite mem_iota.
- apply /andP.
- split; try lia.
- rewrite add0n.
- generalize (Nat.mod_upper_bound x0 (2^n)); lia.
- * rewrite pow_nth_root'; try lia.
- rewrite -nth_root_eq'; try lia.
- rewrite -H1.
- rewrite (mulnC (2 * x0 + 1)%N).
- rewrite !modulo_modn -modnMm -(modnMm (2*j+1)%N).
- do 2 f_equal.
- rewrite muln_modr -expnS.
- assert (1 < 2^n.+1) by lia.
- generalize (modn_small H2); intros.
- by rewrite -{6}H3 modnDm.
- + exists ((2*(i*j)+i+j)%N mod 2^n).
- * rewrite mem_iota.
- apply /andP.
- split; try lia.
- rewrite add0n.
- generalize (Nat.mod_upper_bound (2 * (i * j) + i + j) (2^n)); lia.
- * rewrite pow_nth_root'; try lia.
- rewrite -nth_root_eq'; try lia.
- rewrite !modulo_modn muln_modr -expnS.
- assert (1 < 2^n.+1) by lia.
- generalize (modn_small H1); intros.
- rewrite -{9}H2 modnDm.
- f_equal; lia.
- Qed.
-
- Lemma injective_finite_bijective {S} (l : list S) (f : S -> S) :
- NoDup l ->
- (forall s, In s l -> In (f s) l) ->
- injective f ->
- forall s, In s l <-> In s (map f l).
- Proof.
- intros.
- split; intros.
- - assert (NoDup (map f l)).
- {
- now apply FinFun.Injective_map_NoDup.
- }
- assert (incl l (map f l)).
- {
- apply NoDup_length_incl; trivial.
- - now rewrite map_length.
- - intros ??.
- apply in_map_iff in H4.
- destruct H4 as [? [??]].
- specialize (H0 x H5).
- now rewrite H4 in H0.
- }
- now apply H4.
- - apply in_map_iff in H2.
- destruct H2 as [? [??]].
- rewrite -H2.
- now apply H0.
- Qed.
-
- Lemma injective_finite_permutation {S} (l : list S) (f : S -> S) :
- NoDup l ->
- (forall s, In s l -> In (f s) l) ->
- injective f ->
- Permutation l (map f l).
- Proof.
- intros.
- apply NoDup_Permutation; trivial.
- - now apply FinFun.Injective_map_NoDup.
- - now apply injective_finite_bijective.
- Qed.
-
-
-
- Lemma char_2_opp_eq :
- one T = - (one T) <-> 2%N \in [char T].
- Proof.
- unfold mem, in_mem; simpl.
- rewrite mulr2n.
- split;intros.
- - apply (f_equal (fun (z : T) => 1 + z)) in H.
- rewrite addrN in H.
- by rewrite H.
- - move=> /eqP in H.
- apply (f_equal (fun (z : T) => z - 1)) in H.
- by rewrite add0r -addrA addrN addr0 in H.
- Qed.
-
- End unity.
-
-
diff --git a/coq/FHE/encrypt.v b/coq/FHE/encrypt.v
deleted file mode 100644
index ad7d1fa6..00000000
--- a/coq/FHE/encrypt.v
+++ /dev/null
@@ -1,764 +0,0 @@
-Require Import Lia List.
-From mathcomp Require Import common ssreflect fintype bigop ssrnat matrix ring.
-From mathcomp Require Import ssralg ssrfun ssrint seq.
-From mathcomp Require Import generic_quotient ring_quotient.
-From mathcomp Require Import poly mxpoly polydiv zmodp eqtype ssrbool.
-From mathcomp Require Import intdiv.
-
-Import ssralg.GRing.
-Require Import encode.
-
-Set Bullet Behavior "Strict Subproofs".
-
-Local Open Scope ring_scope.
-
-Lemma nat_of_ordK {p: nat} : cancel (@nat_of_ord (S (S p))) (natmul 1).
-Proof.
- move=> x.
- by rewrite Zp_nat valZpK.
-Qed.
-
-Lemma int_of_ordK {p: nat} : cancel (fun x:'Z_p => Posz (nat_of_ord x)) (intmul 1).
-Proof.
- move=> x.
- by rewrite -pmulrn nat_of_ordK.
-Qed.
-
-Section balance.
-
- Import ssrnum.Num.Syntax.
-
- (* range (-p/2, p/2] *)
- Definition balanced_mod {p:nat} (x : 'Z_p):int :=
- if (x <= p./2)%N then x%:Z else x%:Z-p%:Z.
-
- (* range [-p/2, p/2) *)
- Definition balanced_mod_lo {p:nat} (x : 'Z_p):int :=
- let xz := x %:Z in
- let xzm := xz - p%:Z in
- if -(p./2)%:Z <= xzm then xzm else xz.
-
-
- Lemma absz_bound (x : int) (b : nat) :
- - (b%:Z) <= x /\ x <= b%:Z <->
- (absz x <= b)%N.
- Proof.
- unfold absz.
- split; intros.
- - destruct H.
- case_eq x; intros; lia.
- - case_eq x; intros; rewrite H0 in H; lia.
- Qed.
-
- Lemma absz_bound_lt (x : int) (b : nat) :
- - (b%:Z) < x /\ x < b%:Z <->
- (absz x < b)%N.
- Proof.
- unfold absz.
- split; intros.
- - destruct H.
- case_eq x; intros; lia.
- - case_eq x; intros; rewrite H0 in H; lia.
- Qed.
-
- Context {p : nat} {pbig:(1 < p)%nat}.
-
- Lemma Zp_intmul_Np (x : 'Z_p) :
- x = (x%:Z - p%:Z)%:~R.
- Proof.
- generalize (intmul1_is_rmorphism (Zp_ringType (Zp_trunc p))); intros.
- destruct H.
- by rewrite base int_of_ordK -pmulrn char_Zp // oppr0 addr0.
- Qed.
-
- Import order.Order.TotalTheory.
- Import ssrnum.Num.Theory.
-
- Lemma balanced_mod_cong (x : 'Z_p) :
- x = (balanced_mod x)%:~R.
- Proof.
- unfold balanced_mod.
- case: (x <= p./2)%N.
- - by rewrite int_of_ordK.
- - by rewrite {1}(Zp_intmul_Np x).
- Qed.
-
- Lemma balanced_mod_lo_cong (x : 'Z_p) :
- x = (balanced_mod_lo x)%:~R.
- Proof.
- unfold balanced_mod_lo.
- case: leP => _.
- - by rewrite {1}(Zp_intmul_Np x).
- - by rewrite int_of_ordK.
- Qed.
-
- Lemma Zp_lt_p (x : 'Z_p):
- x%:Z < p.
- Proof.
- generalize (ltn_ord x); intros.
- by rewrite {2}Zp_cast in H.
- Qed.
-
- Lemma Zp_lt_p_N (x : 'Z_p):
- (x < p)%N.
- Proof.
- generalize (ltn_ord x); intros.
- by rewrite {2}Zp_cast in H.
- Qed.
-
- Lemma balanced_mod_range1 (x : 'Z_p):
- balanced_mod x <= p./2.
- Proof.
- unfold balanced_mod.
- generalize (Zp_lt_p_N x).
- intros.
- case leqP; lia.
- Qed.
-
- Lemma balanced_mod_lo_range1 (x : 'Z_p):
- balanced_mod_lo x <= p.-1./2.
- Proof.
- unfold balanced_mod_lo.
- generalize (Zp_lt_p x).
- case: (boolP (- (p./2)%:Z <= _)) => le1; lia.
- Qed.
-
- Lemma balanced_mod_range2 (x : 'Z_p):
- -((p.-1./2)%:Z) <= balanced_mod x.
- Proof.
- unfold balanced_mod.
- case: leqP => HH; try lia.
- Qed.
-
- Lemma balanced_mod_lo_range2 (x : 'Z_p):
- -((p./2)%:Z) <= balanced_mod_lo x.
- Proof.
- unfold balanced_mod_lo.
- case: (boolP (_ <= Posz x - p%:Z)) => le1; lia.
- Qed.
-
- Lemma balanced_mod_abs_range (x : 'Z_p):
- (absz (balanced_mod x) <= p./2)%N.
- Proof.
- apply absz_bound.
- split.
- - generalize (balanced_mod_range2 x); lia.
- - apply balanced_mod_range1.
- Qed.
-
- Lemma balanced_mod_lo_abs_range (x : 'Z_p):
- (absz (balanced_mod_lo x) <= p./2)%N.
- Proof.
- apply absz_bound.
- split.
- - apply balanced_mod_lo_range2.
- - generalize (balanced_mod_lo_range1 x); lia.
- Qed.
-
- Lemma balanced_mod_unique (c1 c2 : int) :
- c1 <= p./2 ->
- c2 <= p./2 ->
- -((p.-1./2)%:Z) <= c1 ->
- -((p.-1./2)%:Z) <= c2 ->
- ((c1 - c2) %% p)%Z = 0 ->
- c1 = c2.
- Proof.
- intros.
- case (leP (0%Z) (c1 - c2)%Z) => le0.
- - assert (le0_lep:0%Z <= c1 - c2 < p%:Z).
- {
- apply /andP; lia.
- }
- generalize (modz_small le0_lep); lia.
- - assert (le0_lep:0%Z <= c2 - c1 < p%:Z).
- {
- apply /andP; lia.
- }
- assert (((c2 - c1) %% p)%Z = 0).
- {
- replace (c2 - c1)%Z with (-1 * (c1 - c2))%Z by lia.
- by rewrite -modzMmr H3 mulr0 mod0z.
- }
- generalize (modz_small le0_lep); lia.
- Qed.
-
-End balance.
-
-Section encrypted_ops.
-
- Variable (q:nat).
- Hypothesis (qodd : (odd q)).
-
- Variable (err_poly secret_poly a_poly : {poly 'Z_q}).
- Variable (ρ : nat). (* bound for errs *)
-
- (* err_poly is small, a_poly is random over 'Z_q *)
-
- Definition public_key := (-a_poly * secret_poly + err_poly, a_poly).
-
- Definition encrypt (m : {poly 'Z_q}) : ({poly 'Z_q} * {poly 'Z_q}) :=
- (m + fst public_key, snd public_key).
-
- Definition encrypt_z (m : {poly int}) := encrypt (red_poly m q).
-
- Definition rounded_div (num : int) (den : nat) :=
- let denz := den %:Z in
- let q := (num %/ denz)%Z in
- let rem := num - q * denz in
- if absz rem <= den./2 then q else q+1.
-
- Lemma add_opp [R : comRingType] (x : R) :
- (-x) + x = 0.
- Proof.
- ring.
- Qed.
-
- Lemma rounded_div_rem_small (num : int) (den : nat) :
- (0 < den)%N ->
- absz (num - (rounded_div num den) * (den%:Z))%Z <= den ./2.
- Proof.
- intros.
- apply absz_bound.
- unfold rounded_div.
- case: (boolP ((`|(num - (num %/ den)%Z * den)%R|) <= _)) => HH.
- - apply absz_bound in HH.
- destruct HH.
- split; lia.
- - split; try lia.
- Qed.
-
- Definition coef_norm {qq:nat} (p : {poly 'Z_qq}) :=
- list_max (map absz (map balanced_mod p)).
-
- Hypothesis (err_poly_small : coef_norm err_poly <= ρ).
-
- Definition decrypt mpair := (fst mpair) + (snd mpair) * secret_poly.
-
- Lemma encrypt_decrypt (m : {poly 'Z_q}) :
- decrypt (encrypt m) = m + err_poly.
- Proof.
- unfold decrypt, encrypt, public_key, fst, snd.
- ring.
- Qed.
-
-(*
- (* following already defined in ssralg *)
- Definition add_pair (p1 p2 : ({poly 'Z_q} * {poly 'Z_q})) :=
- (fst p1 + fst p2, snd p1 + snd p2).
-*)
-
- Definition scale_pair (m : {poly 'Z_q}) (p : ({poly 'Z_q} * {poly 'Z_q})) :=
- (m * fst p, m * snd p).
-
- Definition mul_pair (p1 p2 : ({poly 'Z_q} * {poly 'Z_q})) :=
- (fst p1 * fst p2, (fst p1 * snd p2 + snd p1 * fst p2, snd p1 * snd p2)).
-
- Lemma CKKS_add (m1 m2 : {poly 'Z_q}) :
- decrypt (add_pair (encrypt m1) (encrypt m2)) =
- decrypt (encrypt m1) + decrypt(encrypt m2).
- Proof.
- rewrite !encrypt_decrypt.
- unfold add_pair, decrypt, encrypt, public_key, fst, snd.
- ring.
- Qed.
-
- Lemma CKKS_scale (m1 m2 : {poly 'Z_q}) :
- decrypt (scale_pair m1 (encrypt m2)) =
- m1 * decrypt(encrypt m2).
- Proof.
- unfold scale_pair, encrypt, decrypt, public_key, fst, snd.
- ring.
- Qed.
-
- Definition decrypt_mul trip := fst trip + secret_poly * decrypt (snd trip).
-
- Lemma CKKS_mul_trip (m1 m2 : {poly 'Z_q}) :
- decrypt_mul (mul_pair (encrypt m1) (encrypt m2)) =
- decrypt (encrypt m1) * decrypt (encrypt m2).
- Proof.
- unfold mul_pair, encrypt, decrypt_mul, decrypt, public_key, fst, snd.
- ring.
- Qed.
-
- Variable (p:nat). (* relin_modulus *)
- Hypothesis pbig : p > q.
-
- Definition pq_embed (c : 'Z_q) : 'Z_(p*q) := (balanced_mod c)%:~R.
-
- Definition secret_p := map_poly pq_embed secret_poly.
-
- Variable (relin_err relin_a : {poly 'Z_(p*q)}).
- Hypothesis (relin_err__small : coef_norm relin_err <= ρ).
-
- Definition rescale (q1 q2 : nat) (c : 'Z_(q1*q2)) : 'Z_q2 :=
- (rounded_div (balanced_mod c) q1)%:~R.
-
- Definition rescale_gen (q1 q2 : nat) (c : 'Z_q1) : 'Z_q2 :=
- (rounded_div ((balanced_mod c) * q2) q1)%:~R.
-
- Definition scale_up (q1 q2 : nat) (c : 'Z_q1) : 'Z_(q1*q2) :=
- inZp (muln q2 c).
-
- Lemma scale_up_additive (q1 q2 : nat):
- additive (scale_up (Zp_trunc q1).+2 (Zp_trunc q2).+2).
- Proof.
- intros x y.
- rewrite /scale_up /add /opp /= /inZp.
- apply ord_inj => /=.
- set q1' := (Zp_trunc q1).+2.
- set q2' := (Zp_trunc q2).+2.
- rewrite {2 4}(@Zp_cast (q1' * q2')) //.
- rewrite !div.modnDmr !div.modnDml.
- rewrite div.muln_modr.
- unfold q1', q2'.
- replace ((Zp_trunc ((Zp_trunc q1).+2 * (Zp_trunc q2).+2)).+2 ) with
- ((Zp_trunc q2).+2 * (Zp_trunc q1).+2)%N at 1 by (unfold Zp_trunc; lia).
- rewrite div.modn_mod.
- f_equal; try lia.
- rewrite div.modn_small.
- - rewrite mulnDr mulnBr.
- f_equal.
- f_equal.
- rewrite mulnC.
- unfold Zp_trunc.
- lia.
- - rewrite mulnC.
- have: (ltn y (Zp_trunc q1).+2).
- + apply ltn_ord.
- + assert (0 < (Zp_trunc q2).+2).
- {
- unfold Zp_trunc; lia.
- }
- by rewrite ltn_mul2r.
- Qed.
-
- Definition rescale1 (q1 q2 : nat) (c : 'Z_(q1*q2)) : 'Z_q2 := inZp c.
-
- Lemma rescale1_is_rmorphism (q1 q2 : nat) :
- rmorphism (rescale1 (Zp_trunc q1).+2 (Zp_trunc q2).+2).
- Proof.
- unfold rescale1.
- generalize (intmul1_is_rmorphism (Zp_ringType (Zp_trunc q2))); intros.
- destruct H as [? [? ?]].
- constructor.
- - intros x y.
- rewrite /add/opp/=/Zp_add /= /inZp.
- apply ord_inj => /=.
- set q1' := (Zp_trunc q1).+2.
- set q2' := (Zp_trunc q2).+2.
- rewrite {2 4}(@Zp_cast (q1' * q2')) //.
- rewrite !div.modnDmr !div.modnDml div.modn_dvdm.
- + suff: Posz (div.modn (x + (q1' * q2' - y)) q2') = Posz (div.modn (x + (q2' - div.modn y q2')) q2')
- by inversion 1.
- rewrite -!modz_nat !PoszD -!ssrint.subzn.
- * rewrite -modzDmr -(modzDml (muln q1' q2')).
- rewrite PoszM modzMl add0r modzDmr -!modz_nat.
- rewrite -(modzDmr x (Posz q2' - _)).
- rewrite -(modzDml q2' _).
- rewrite modzz add0r modzNm.
- by rewrite modzDmr.
- * apply ltnW.
- by apply div.ltn_pmod.
- * apply ltnW.
- apply ltn_ord.
- + by rewrite div.dvdn_mull.
- - constructor.
- + intros x y.
- rewrite -!Zp_nat !pmulrn.
- rewrite -m -!pmulrn !Zp_nat /inZp.
- apply ord_inj => //=.
- rewrite div.modn_dvdm //.
- rewrite {1}Zp_cast //.
- rewrite (@Zp_cast ((Zp_trunc q1).+2 * (Zp_trunc q2).+2)) //.
- by rewrite div.dvdn_mull //.
- + by rewrite /= div.modn_small //.
- Qed.
-
- Canonical rescale1_rmorphism (q1 q2: nat) := RMorphism (rescale1_is_rmorphism q1 q2).
-
- Lemma cdivqq_int (q1 q2 : nat) (c : int):
- (0 < q2)%N ->
- (c %/ q1)%Z = ((c * q2) %/ (q1 * q2)%N)%Z.
- Proof.
- intros.
- rewrite -(@divzMpr q2%:Z); [| lia].
- do 2 f_equal.
- Qed.
-
- Lemma lt_muln_iff (n1 n2 n3 : nat) :
- (n1 < n2)%N <-> (n1 * (S n3) < n2 * (S n3))%N.
- Proof.
- induction n3; lia.
- Qed.
-
- Lemma le_half_odd (n1 n2 : nat) :
- odd n2 ->
- (n1 <= n2./2)%N <-> (n1.*2.+1 <= n2)%N.
- Proof.
- lia.
- Qed.
-
- Lemma le_half_mul_odd (n1 n2 n3 : nat) :
- odd n2 ->
- odd n3 ->
- (n1 <= n2./2)%N <-> (n1 * n3 <= (n2 * n3)./2)%N.
- Proof.
- intros.
- rewrite le_half_odd // le_half_odd; try lia.
- replace ((n1 * n3).*2) with ((n1.*2)*n3)%N by lia.
- replace n3 with (n3.-1.+1) by lia.
- apply lt_muln_iff.
- Qed.
-
- Lemma rounded_div_scale_div (q1 q2 : nat) (c : int):
- odd q1 ->
- odd q2 ->
- rounded_div c q1 = rounded_div (c * q2) (q1 * q2).
- Proof.
- intros.
- assert (0 < q2)%N by lia.
- rewrite /rounded_div -!cdivqq_int //.
- have: ((c * q2 - (c %/ q1)%Z * (q1 * q2)%N)%R =
- (c - (c %/ q1)%Z * q1)%R * q2) by lia.
- move ->.
- rewrite abszM absz_nat.
- generalize (le_half_mul_odd `|(c - (c %/ q1)%Z * q1)%R| q1 q2 H H0); intros.
- case: leP; case: leP; lia.
- Qed.
-
- Lemma rescale_gen_prop (q1 q2 : nat) (c : 'Z_(q1*q2)):
- odd q1 ->
- odd q2 ->
- rescale q1 q2 c = rescale_gen (q1 * q2) q2 c.
- Proof.
- intros.
- unfold rescale, rescale_gen, balanced_mod.
- by rewrite -rounded_div_scale_div.
- Qed.
-
- Definition red_p_q (c : 'Z_(p*q)) : 'Z_q := rescale p q c.
-
- Definition relin_V2_aux (c2 : {poly 'Z_q}) :=
- let b := - relin_a * secret_p + (secret_p ^+ 2)*+p + relin_err in
- let cp := map_poly pq_embed c2 in
- (map_poly red_p_q (cp * b), map_poly red_p_q (cp * relin_a)).
-
- Definition relin_V2 trip :=
- add_pair (fst trip, fst (snd trip))
- (relin_V2_aux (snd (snd trip))).
-
- Definition CKKS_mul (p1 p2 : ({poly 'Z_q} * {poly 'Z_q})) :
- ({poly 'Z_q} * {poly 'Z_q}) :=
- relin_V2 (mul_pair p1 p2).
-
-End encrypted_ops.
-
-Section rotation.
-
- (* show p x -> p (x^k) is a morphism *)
- Definition poly_shift [R:ringType] (k : nat) (p : {poly R}) : {poly R}
- := comp_poly 'X^k p.
-
- Definition poly_shift_alt [R:ringType] (k : nat) (p : {poly R}) : {poly R}
- := \poly_(i < (k * (seq.size p).-1).+1) (if div.dvdn k i then p`_(div.divn i k) else 0).
-
- Lemma poly_shift_altE [R:ringType] (k : nat) (p : {poly R}) :
- poly_shift k.+1 p = poly_shift_alt k.+1 p.
- Proof.
- case: (@eqP _ (seq.size p) 0%nat).
- - move/seq.size0nil.
- rewrite -polyseq0 => /poly_inj->.
- rewrite /poly_shift /poly_shift_alt comp_poly0.
- apply polyP => i.
- rewrite coef_poly !coef0.
- case: ltP => //.
- by case: div.dvdnP => //.
- - move=> pn0.
- apply polyP => i.
- rewrite /poly_shift /poly_shift_alt.
- rewrite !coef_comp_poly_Xn //= coef_poly /=.
- case: div.dvdnP.
- + move=>[m ->].
- rewrite div.mulnK //.
- case: ltP => // nlt.
- rewrite seq.nth_default //.
- move /ltP: nlt.
- rewrite mulnC ltnS leq_pmul2l //.
- rewrite leqNgt Bool.negb_involutive.
- by case: (seq.size p) pn0.
- + by case: ltP.
- Qed.
-
- Lemma poly_shift_altE' [R:ringType] (k : nat) (p : {poly R}) : k != 0%nat ->
- poly_shift k p = poly_shift_alt k p.
- Proof.
- destruct k => // _.
- apply poly_shift_altE.
- Qed.
-
- Lemma poly_shift_1 [R:ringType] (k : nat):
- @poly_shift R k 1 = 1.
- Proof.
- by rewrite /poly_shift comp_polyC.
- Qed.
-
- Lemma poly_shift_is_rmorphism [R:comRingType] (k : nat) :
- rmorphism (poly_shift (R := R) k).
- Proof.
- unfold poly_shift.
- constructor.
- - intros ??.
- by rewrite comp_polyB.
- - split.
- + intros ??.
- by rewrite comp_poly_multiplicative.
- + by rewrite comp_polyC polyC1.
- Qed.
-
- Lemma poly_shift_injective [R:ringType] (k:nat) : injective (poly_shift (R:=R) (S k)).
- Proof.
- move=> a b eqq.
- apply polyP => i.
- apply (f_equal (coefp (k.+1 * i))) in eqq.
- move: eqq.
- rewrite /poly_shift /=.
- rewrite !coef_comp_poly_Xn //=.
- rewrite !div.dvdn_mulr //.
- by rewrite !div.mulKn //.
- Qed.
-
- Lemma poly_shift1_id [R:ringType] (p : {poly R}):
- @poly_shift R 1 p = p.
- Proof.
- apply polyP => i.
- rewrite /poly_shift /=.
- rewrite !coef_comp_poly_Xn //=.
- by rewrite div.dvd1n div.divn1.
- Qed.
-
- Lemma size_poly_shift [R:ringType] (k:nat) (p : {poly R}) (pn0:p!=0) :
- seq.size (poly_shift (k.+1) p) = (k.+1 * (seq.size p).-1).+1%nat.
- Proof.
- rewrite poly_shift_altE.
- rewrite size_poly_eq //=.
- rewrite div.dvdn_mulr ?div.dvdnn //.
- rewrite div.mulKn //.
- by rewrite -lead_coefE lead_coef_eq0.
- Qed.
-
- Lemma size_poly_shift' [R:ringType] (k:nat) (p : {poly R}) (pn0:p!=0) :
- k != 0%nat ->
- seq.size (poly_shift k p) = (k * (seq.size p).-1).+1%nat.
- Proof.
- elim: k => //.
- move=> k _ _.
- by apply size_poly_shift.
- Qed.
-
- Definition poly_unshift [R:ringType] (k : nat) (p : {poly R}) :=
- \poly_(i < (div.divn (seq.size p).-1 k).+1) (p`_(k*i)).
-
- Lemma poly_shiftK [R:comRingType] (k: nat) : cancel (@poly_shift R (S k)) (@poly_unshift R (S k)).
- Proof.
- move=> p.
- case: (@eqP _ p 0).
- - move=> -> /=.
- rewrite /poly_shift comp_poly0 /poly_unshift.
- apply polyP=> i.
- rewrite coef_poly !polyseq0 /= !seq.nth_nil.
- by case: ltP.
- - rewrite /poly_unshift => /eqP-pn0.
- rewrite size_poly_shift //.
- rewrite poly_shift_altE /poly_shift_alt.
- apply polyP=> i.
- rewrite coef_poly => /=.
- rewrite div.mulKn //.
- rewrite -polySpred //.
- case: ltP.
- + move=> ilt.
- rewrite coef_poly div.mulKn //.
- rewrite div.dvdn_mulr ?div.dvdnn //.
- rewrite ltnS leq_pmul2l //.
- rewrite polySpred // in ilt.
- rewrite -ltnS.
- by move/ltP: ilt => ->.
- + move=> inlt.
- rewrite seq.nth_default //.
- rewrite leqNgt.
- by apply/ltP.
- Qed.
-
-
- Lemma comp_poly_exp_polyX [R:ringType] j k :
- (polyX R) ^+ (j * k) = comp_poly ('X^ j) ('X^ k).
- Proof.
- by rewrite comp_Xn_poly /= -exprM.
- Qed.
-
- Lemma poly_shiftM [R:comRingType] (j k: nat) (p: {poly R}) :
- poly_shift (j * k) p = poly_shift j (poly_shift k p).
- Proof.
- by rewrite /poly_shift -comp_polyA comp_poly_exp_polyX.
- Qed.
-
- Lemma lin_div_odd_power [R:ringType] k :
- odd k ->
- Pdiv.Ring.rdvdp (R := R) ('X + 1%:P) ('X^k + 1%:P).
- Proof.
- rewrite -{1}(opprK 1%:P).
- replace (- polyC (R:=R) 1) with (polyC (R:=R) (-1)).
- - intros.
- rewrite Pdiv.Ring.rdvdp_XsubCl /root hornerD hornerXn hornerC.
- by rewrite -signr_odd H /= expr1 addrC addrN.
- - by rewrite polyCN polyC1.
- Qed.
-
- Lemma rdvdp_comp_poly_monic [R:comRingType] (r p q : {poly R}) :
- p \is monic ->
- p \Po r \is monic ->
- Pdiv.Ring.rdvdp p q ->
- Pdiv.Ring.rdvdp (p \Po r) (q \Po r).
- Proof.
- move=> monp monpr.
- have [-> | pn0] := eqVneq p 0.
- - by rewrite comp_poly0 !Pdiv.Ring.rdvd0p; move/eqP->; rewrite comp_poly0.
- - rewrite Pdiv.ComRing.rdvdp_eq.
- rewrite (monicP monp) expr1n /= scale1r.
- set s := Pdiv.Ring.rdivp (R:=R) _ _; move/eqP=> Hq.
- apply: (@mathcomp.algebra.polydiv.Pdiv.RingMonic.eq_rdvdp _ _ _ (s \Po r)).
- + trivial.
- + by rewrite -comp_polyM -{}Hq.
- Qed.
-
- Lemma pow2_div_odd_power [R:comRingType] k n :
- odd k ->
- Pdiv.Ring.rdvdp (R := R) ('X^(2^n) + 1%:P) ('X^k ^+(2^n) + 1%:P).
- Proof.
- move=> oddk.
- case: (@eqVneq _ n 0%nat).
- - move=> ->.
- rewrite expn0 !expr1.
- by apply lin_div_odd_power.
- - move=> nn0.
- move: (rdvdp_comp_poly_monic (R:=R) ('X^(2 ^ n)) ('X + 1%:P) ('X^k + 1%:P)).
- rewrite lin_div_odd_power //.
- rewrite (Xn_add_c_monic 0).
- rewrite !comp_polyD !comp_polyX !comp_polyC.
- have-> : 'X^(2 ^ n) + 1%:P \is @monic R.
- {
- case: (@eqP _ (expn 2 n) 0%nat) =>eqq.
- - lia.
- - destruct (expn 2 n); [lia |].
- apply Xn_add_c_monic.
- }
- rewrite comp_Xn_poly -!exprM.
- rewrite [muln (2^n) k]mulnC.
- by apply.
- Qed.
-
-End rotation.
-
- Require Import Reals nth_root encode.
- From mathcomp Require Import Rstruct complex.
-
- Lemma poly_shift_eval [S : comRingType] (p : {poly S}) (k : nat) (v : S) :
- p.[v^+k] = (poly_shift k p).[v].
- Proof.
- unfold poly_shift.
- by rewrite horner_comp hornerXn.
- Qed.
-
- Lemma poly_shift_C (p : {poly R}) (k : nat) :
- poly_shift k (map_poly RtoC p) = map_poly RtoC (poly_shift k p).
- Proof.
- by rewrite /poly_shift map_comp_poly map_polyXn.
- Qed.
-
- Lemma poly_shift_eval_C (p : {poly R}) (k:nat) (v : R[i]) :
- (map_poly RtoC p).[v^+k] = (map_poly RtoC (poly_shift k p)).[v].
- Proof.
- by rewrite poly_shift_eval poly_shift_C.
- Qed.
-
- Lemma conj_poly_eval_pow (p : {poly R}) (i j :nat) :
- let v := nth_root i (S j) in
- conjc ((map_poly RtoC p).[v]) = (map_poly RtoC (poly_shift j p)).[v].
- Proof.
- simpl.
- rewrite -poly_shift_eval_C.
- rewrite rpoly_eval_conj.
- f_equal.
- by rewrite -conj_pow_nth_root.
- Qed.
-
- From mathcomp Require Import ssrnat div.
- Lemma poly_odd_pow_prim_roots_perm_eq (j n : nat) (p : {poly R[i]}):
- let l := mkseq (fun i => nth_root (2 * i + 1) (2 ^ (S n))) (2^n) in
- perm_eq (map (fun x => p.[x]) l)
- (map (fun x => (poly_shift (2*j+1) p).[x]) l).
- Proof.
- pose (l := mkseq (fun i => nth_root (2 * i + 1) (2 ^ (S n))) (2^n)).
- have /= ->: map (fun x => (poly_shift (2*j+1) p).[x]) l =
- map (fun x => p.[x]) (map (fun x => x^+(2*j+1)) l).
- {
- rewrite -map_comp /ssrfun.comp.
- apply eq_map=> a.
- by rewrite poly_shift_eval.
- }
- apply perm_map.
- apply odd_pow_prim_roots_perm_eq.
- Qed.
-
- Lemma nth_root_pow_trans n i k :
- exists j,
- (nth_root (2 * i + 1) (2 ^ (S n))) ^+ (2 * j + 1) =
- nth_root (2 * k + 1) (2 ^ (S n)).
- Proof.
- assert (exists j,
- (2 * i + 1) * (2 * j + 1) = 2 * k + 1 %[mod 2^(S n)]).
- {
- generalize (pow2_odd_inv (2 * i + 1) n); intros.
- destruct H; try lia.
- generalize (pow2_odd_inv_aux i x n); intros.
- rewrite (mulnC x _) in H.
- rewrite modulo_modn in e.
- specialize (H e).
- destruct H.
- rewrite H in e.
- exists (2 * x0 * k + x0 + k)%N.
- replace (2 * (2 * x0 * k + x0 + k) + 1)%N with
- ((2 * x0 + 1) * (2 * k + 1))%N by lia.
- rewrite mulnA.
- assert (1 %% 2^n.+1 = 1)%N.
- {
- rewrite modn_small; lia.
- }
- rewrite -{6}H0 in e.
- apply (f_equal (fun z => (z * (2 * k + 1) %% 2^ n.+1)%N)) in e.
- by rewrite H0 mul1n modnMml in e.
- }
- destruct H.
- exists x.
- rewrite pow_nth_root'; try lia.
- apply nth_root_eq'; try lia.
- by rewrite mulnC !modulo_modn.
- Qed.
-
- Lemma poly_odd_pow_prim_roots_perm_trans (n : nat) (p : {poly R[i]}):
- forall i k,
- exists j,
- (poly_shift (2*j+1) p).[nth_root (2 * i + 1) (2 ^ (S n))] =
- p.[nth_root (2 * k + 1) (2 ^ (S n))].
- Proof.
- intros.
- destruct (nth_root_pow_trans n i k).
- exists x.
- by rewrite -poly_shift_eval H.
- Qed.
-
-
-
-
-
-
-
diff --git a/coq/FHE/nth_root.v b/coq/FHE/nth_root.v
deleted file mode 100644
index a5dafc2f..00000000
--- a/coq/FHE/nth_root.v
+++ /dev/null
@@ -1,885 +0,0 @@
-Require Import Reals Lra Lia List.
-From mathcomp Require Import complex ssreflect common eqtype ssrint ssrnat Rstruct.
-Import ssralg.GRing.
-Import ssralg.
-
-Ltac coq_lra := lra.
-
-From mathcomp Require Import lra.
-
-Set Bullet Behavior "Strict Subproofs".
-
-Lemma S_INR_not_0 n :
- INR (S n) <> Rdefinitions.R0.
-Proof.
- rewrite S_INR.
- generalize (pos_INR n).
- coq_lra.
-Qed.
-
-Lemma S_INR_n0 n : is_true (INR (S n) != (zero R_ringType)).
-Proof.
- intros.
- move: (S_INR_not_0 n) => HH.
- by case eqP.
-Qed.
-
-(* represent complex number as pair *)
-Definition nth_root (j n : nat) : R[i] :=
- let c := (2*PI*INR(j)/INR(n))%R in
- ((cos c) +i* (sin c))%C.
-
-Local Open Scope ring_scope.
-Delimit Scope complex_scope with C.
-Local Open Scope complex_scope.
-
-Definition RtoC (x : R) := Complex x Rdefinitions.R0.
-Definition C1 := Complex R1 Rdefinitions.R0.
-Definition C0 := Complex Rdefinitions.R0 Rdefinitions.R0.
-Definition C := ComplexField.complex_ringType R_fieldType.
-
-Lemma nth_root_0 n :
- nth_root 0 (S n) = 1.
-Proof.
- unfold nth_root.
- assert ((2 * PI * INR 0 / INR (S n))%R = 0%R).
- {
- unfold INR at 1.
- by rewrite mulr0 mul0r.
- }
- by rewrite H /zero /= cos_0 sin_0.
-Qed.
-
-Lemma nth_root_2PI n j :
- nth_root (j * (S n)) (S n) = 1.
-Proof.
- unfold nth_root.
- replace (2 * PI * INR (j * S n) / INR (S n))%R with
- (0 + 2 * (INR j) * PI)%R.
- - by rewrite cos_period sin_period /zero /= cos_0 sin_0.
- - rewrite add0r mult_INR -mulrA -mulrA.
- rewrite (mulrC _ PI) -(mulrA 2 _).
- f_equal.
- f_equal.
- rewrite -mulrA divff.
- + by rewrite mulr1.
- + by apply S_INR_n0.
-Qed.
-
-Lemma nth_root_2PI_plus n j k :
- nth_root (j + k * (S n)) (S n) = nth_root j (S n).
-Proof.
- unfold nth_root.
- replace (2 * PI * INR (j + k * S n) / INR (S n))%R with
- (2 * PI * INR(j)/INR(S n) + 2 * INR k * PI)%R.
- - now rewrite cos_period; rewrite sin_period.
- - rewrite plus_INR; rewrite mult_INR.
- rewrite (mulrDr (2 * PI) _ _) mulrDl.
- f_equal.
- rewrite -mulrA (mulrC _ PI) -mulrA -mulrA.
- f_equal.
- f_equal.
- rewrite -mulrA divff.
- + by rewrite mulr1.
- + by apply S_INR_n0.
-Qed.
-
-Definition nth_roots (n:nat) :=
- map (fun j => nth_root j n) (seq 0 n).
-
-
-Lemma de_moivre (x : R) (n : nat) :
- exp (cos x +i* sin x) n = (cos ((INR n) * x)%R +i* sin ((INR n) * x)%R).
-Proof.
- rewrite /mul /= -iter_mulr_1.
- induction n.
- - simpl.
- rewrite Rmult_0_l.
- now rewrite cos_0; rewrite sin_0.
- - simpl iter.
- rewrite IHn S_INR Rmult_plus_distr_r Rmult_1_l.
- rewrite cos_plus sin_plus /= /mul /=.
- rewrite /mul /add /opp /=.
- f_equal; ring.
- Qed.
-
-Lemma exp_nth_root j n e :
- exp (nth_root j (S n)) e = nth_root (e * j) (S n).
-Proof.
- unfold nth_root.
- rewrite de_moivre mult_INR.
- assert ((INR e * (2 * PI * INR j / INR n.+1)%R)%R = 2 * PI * (INR e * INR j)%R / INR n.+1).
- {
- replace (S n) with (n + 1)%nat by lia.
- unfold mul, inv; simpl.
- field.
- }
- by rewrite -H.
-Qed.
-
-Lemma exp_nth_root_comm j n e :
- exp (nth_root j (S n)) e = exp (nth_root e (S n)) j.
-Proof.
- do 2 rewrite exp_nth_root.
- f_equal.
- apply mulnC.
-Qed.
-
-Lemma nth_root_npow j n :
- exp (nth_root j (S n)) (S n) = 1.
-Proof.
- by rewrite exp_nth_root mulnC nth_root_2PI.
-Qed.
-
-Lemma minus_mod (j1 j2 n : nat) :
- j1 mod (S n) = j2 mod (S n) ->
- (j2 - j1) mod (S n) = 0%nat.
-Proof.
- intros eqq1.
- destruct (le_dec j1 j2).
- - generalize (Zdiv.Zminus_mod (Z.of_nat j2) (Z.of_nat j1) (Z.of_nat (S n)))
- ; intros HH.
- rewrite <- Nat2Z.inj_sub in HH by trivial.
- repeat rewrite <- Nat2Z.inj_mod in HH.
- rewrite -eqq1 Z.sub_diag Zdiv.Zmod_0_l in HH.
- apply (f_equal Z.to_nat) in HH.
- now rewrite Nat2Z.id in HH.
- - unfold subn, subn_rec.
- rewrite Minus.not_le_minus_0_stt; trivial.
- now apply Nat.mod_0_l.
-Qed.
-
-Lemma nth_root_mod j1 j2 n :
- j1 mod (S n) = j2 mod (S n) ->
- nth_root j1 (S n) = nth_root j2 (S n).
-Proof.
- intros.
- destruct (le_dec j1 j2).
- - assert (exists (k:nat), (j2 = j1 + k * (S n))%N).
- {
- generalize (Nat.div_mod_eq (j2 - j1) (S n)); intros.
- exists ((j2 - j1)/(S n))%N.
- rewrite minus_mod in H0; trivial; lia.
- }
- destruct H0.
- rewrite H0.
- now rewrite nth_root_2PI_plus.
- - assert (exists (k:nat), (j1 = j2 + k * (S n))%N).
- {
- generalize (Nat.div_mod_eq (j1 - j2) (S n)); intros.
- exists ((j1 - j2)/(S n))%N.
- rewrite minus_mod in H0; lia.
- }
- destruct H0.
- rewrite H0.
- now rewrite nth_root_2PI_plus.
- Qed.
-
-Lemma prim_nth_root j n :
- nth_root j (S n) = exp (nth_root 1 (S n)) j.
-Proof.
- rewrite exp_nth_root.
- f_equal.
- lia.
- Qed.
-
-Lemma nth_root_not_0 j n :
- nth_root j (S n) <> 0.
-Proof.
- rewrite /nth_root.
- generalize (cos_sin_0 (2 * PI * INR j / INR (S n))%R); intros.
- intros ?.
- apply H.
- split.
- - by apply (f_equal (fun c => Re c)) in H0.
- - by apply (f_equal (fun c => Im c)) in H0.
- Qed.
-
-Lemma cos1_sin0 (x : R) :
- cos x = 1 ->
- sin x = 0.
-Proof.
- intros eqq1.
- generalize (cos2 x).
- rewrite eqq1; intros eqq2.
- rewrite Rsqr_1 in eqq2.
- apply Rsqr_0_uniq.
- coq_lra.
-Qed.
-
-Section sin_cos.
- Local Open Scope R_scope.
-
-Lemma cosneg1_sin0 (x : R) :
- cos x = - 1 ->
- sin x = 0.
-Proof.
- intros eqq1.
- generalize (cos2 x).
- rewrite eqq1; intros eqq2.
- rewrite -Rsqr_neg Rsqr_1 in eqq2.
- apply Rsqr_0_uniq.
- coq_lra.
-Qed.
-
-Lemma cos_sin0_alt (x : R) :
- sin x = 0 <->
- Rsqr(cos x) = 1.
-Proof.
- split; intro eqq.
- - generalize (sin2_cos2 x); intros.
- rewrite eqq in H.
- rewrite Rsqr_0 in H.
- now rewrite Rplus_0_l in H.
- - generalize (sin2 x); intros.
- rewrite eqq in H.
- rewrite Rminus_eq_0 in H.
- now apply Rsqr_0_uniq in H.
-Qed.
-
-Lemma Rsqr_1_iff (x : R) :
- Rsqr x = 1 <->
- x = 1 \/ x = -1.
-Proof.
- generalize (Rsqr_eq x 1); intros.
- rewrite Rsqr_1 in H.
- split; intros.
- - now apply H.
- - unfold Rsqr.
- destruct H0; rewrite H0; coq_lra.
-Qed.
-
-Lemma cos_sin0 (x : R) :
- sin x = 0 <->
- cos x = 1 \/ cos x = -1.
-Proof.
- split; intros.
- - apply cos_sin0_alt in H.
- now apply Rsqr_1_iff in H.
- - apply Rsqr_1_iff in H.
- now apply cos_sin0_alt in H.
-Qed.
-
-
-Lemma cos_eq_1_aux_pos (x : R) :
- cos x = 1 ->
- exists k, x = (PI * IZR(k))%R.
-Proof.
- intros eqq1.
- generalize (cos1_sin0 _ eqq1); intros eqq2.
- apply sin_eq_0_0 in eqq2.
- destruct eqq2 as [k eqqk].
- exists k.
- unfold mul; simpl; coq_lra.
-Qed.
-
-Lemma cos_eq_1_aux_neg (x : R) :
- cos x = - 1 ->
- exists k, x = (PI * IZR(k))%R.
-Proof.
- intros eqq1.
- generalize (cosneg1_sin0 _ eqq1); intros eqq2.
- apply sin_eq_0_0 in eqq2.
- destruct eqq2 as [k eqqk].
- exists k.
- unfold mul; simpl; coq_lra.
-Qed.
-
-Lemma sin_eq_0_aux (x : R) :
- sin x = 0 ->
- exists k, x = (PI * IZR(k))%R.
-Proof.
- intros.
- apply cos_sin0 in H.
- destruct H.
- - now apply cos_eq_1_aux_pos.
- - now apply cos_eq_1_aux_neg.
-Qed.
-
-Lemma cos_eq_1_1 :
- forall k:Z,
- cos (IZR k * 2 * PI)%R = 1.
-Proof.
- intros k.
- assert (forall n, cos (INR n * 2 * PI) = 1%R). {
- intros n;induction n as [|n IHn].
- { change (INR 0) with Rdefinitions.R0.
- rewrite !Rmult_0_l.
- exact cos_0. }
- rewrite S_INR !Rmult_plus_distr_r cos_plus IHn.
- rewrite !Rmult_1_l cos_2PI sin_2PI Rmult_0_r Rminus_0_r.
- reflexivity.
- }
- destruct (Z.abs_or_opp_abs k).
- - replace (IZR k) with (INR (Z.to_nat k)).
- { apply H. }
- rewrite INR_IZR_INZ.
- f_equal.
- apply Z2Nat.id.
- lia.
- - replace (IZR k) with (- INR (Z.to_nat (- k)))%R.
- + by rewrite mulNr mulNr cos_neg H.
- + rewrite INR_IZR_INZ.
- rewrite <-opp_IZR. f_equal.
- lia.
-Qed.
-
-Lemma cos_lt_1 (x : R) :
- 0 < x ->
- x < 2*PI ->
- cos x < 1.
-Proof.
- intros.
- generalize (COS_bound x); intros.
- generalize PI_RGT_0; intro pi_gt.
- destruct H1.
- assert (cos x <> 1)%R.
- {
- unfold not.
- intros.
- generalize (cos_eq_1_aux_pos x H3); intros.
- destruct H4.
- rewrite H4 mulrC /mul /= in H0.
- apply Rmult_lt_reg_r in H0; trivial.
- rewrite H4 in H.
- replace (IZR Z0) with (Rmult PI 0)%R in H by coq_lra.
- rewrite /mul /= in H.
- apply Rmult_lt_reg_l in H; trivial.
- assert (x0 = Z.one).
- {
- apply lt_IZR in H.
- apply lt_IZR in H0.
- unfold Z.one.
- lia.
- }
- rewrite H5 /Z.one /IZR /IPR mulr1 in H4.
- rewrite H4 cos_PI /one /= in H3.
- coq_lra.
- }
- rewrite /one /= in H3.
- coq_lra.
- Qed.
-
-Lemma cos_eq_1 (x : R) :
- cos x = 1 ->
- exists (k:Z), x = (2 * PI * IZR(k))%R.
-Proof.
- intros Hx.
- assert (PI2_neq0: (2 * PI <> 0)%R).
- {
- generalize PI_neq0.
- unfold mul, one, zero, natmul, add; simpl.
- coq_lra.
- }
- destruct (euclidian_division x (2*PI) PI2_neq0) as (q & r & EQ & Hr & Hr').
- exists q.
- rewrite <- (Rplus_0_r (_*_)). subst.
- rewrite Rmult_comm.
- apply Rplus_eq_compat_l.
- rewrite cos_plus in Hx.
- assert (H : cos (IZR q * 2 * PI)%R = 1%R) by ( apply cos_eq_1_1; now exists q).
- rewrite -Rmult_assoc H /one /= Rmult_1_l sin_eq_0_1 in Hx.
- - rewrite Rmult_0_l Rminus_0_r in Hx.
- rewrite Rabs_right in Hr'.
- + destruct Hr as [Hr | ->]; trivial.
- exfalso.
- generalize (cos_lt_1 r Hr Hr'); intros.
- coq_lra.
- + generalize PI_RGT_0; coq_lra.
- - exists (Z.mul 2 q).
- rewrite mult_IZR.
- coq_lra.
- Qed.
-
-Lemma cos_eq_neg1 (x : R) :
- cos x = -1 ->
- exists k, x = (2 * PI * IZR(k) + PI)%R.
-Proof.
- intros eqq1.
- generalize (Rtrigo_facts.cos_pi_plus x); intros eqq2.
- rewrite eqq1 in eqq2.
- rewrite Ropp_involutive in eqq2.
- apply cos_eq_1 in eqq2.
- destruct eqq2 as [k eqq2].
- exists (Z.sub k 1)%Z.
- rewrite minus_IZR.
- replace (Rplus PI x) with (PI + x)%R in eqq2.
- - replace (Rminus (IZR k) 1) with ((IZR k) - 1)%R.
- + lra.
- + unfold add, opp, one; simpl; coq_lra.
- - unfold add; simpl; coq_lra.
-Qed.
-
-Lemma cos_eq_1_nneg (x : R) :
- cos x = 1 ->
- 0 <= x ->
- exists (k:nat), x = (2 * PI * INR(k))%R.
-Proof.
- intros.
- generalize (cos_eq_1 x H); intros.
- destruct H1.
- rewrite H1 in H0.
- replace (IZR Z0) with (2 * PI * 0)%R in H0.
- - apply Rmult_le_reg_l in H0.
- + unfold zero in H0; simpl in H0.
- apply le_IZR in H0.
- exists (Z.abs_nat x0).
- rewrite H1.
- do 2 f_equal.
- destruct x0; simpl; trivial; try lia.
- now rewrite INR_IPR.
- + unfold mul, one, natmul, add; simpl.
- generalize PI_RGT_0; coq_lra.
- - by rewrite mulr0.
-Qed.
-
-Lemma sin_cos_eq x y:
- sin x = sin y /\ cos x = cos y ->
- exists (k:Z),
- x = (y + 2 * PI * IZR(k))%R.
-Proof.
- intros.
- generalize (cos_minus x y); intros.
- destruct H.
- rewrite H H1 in H0.
- generalize (sin2_cos2 y); intros.
- rewrite Rplus_comm in H0.
- unfold Rsqr in H2.
- rewrite H2 in H0.
- apply cos_eq_1 in H0.
- destruct H0.
- exists x0.
- rewrite <- H0.
- unfold add; simpl.
- coq_lra.
-Qed.
-
-Lemma Pi2_neq0 :
- (2 * PI <> 0)%R.
-Proof.
- generalize PI_neq0.
- unfold mul, one, zero, natmul, add; simpl.
- coq_lra.
-Qed.
-
-Lemma Pi2_neq0_alt :
- is_true (2 * PI != 0).
-Proof.
- generalize Pi2_neq0.
- by case eqP.
-Qed.
-
-End sin_cos.
-
-Lemma nth_root_eq j k n :
- j mod (S n) = k mod (S n) <->
- nth_root j (S n) = nth_root k (S n).
-Proof.
- split; intros.
- - now apply nth_root_mod.
- - unfold nth_root in H.
- replace (S n) with (addn n 1) in H by lia.
- inversion H; clear H.
- pose (jj := (2 * PI * INR j)/ (INR (addn n 1))).
- pose (kk := (2 * PI * INR k)/ (INR (addn n 1))).
- generalize (sin_cos_eq jj kk); intros.
- destruct H.
- + split; trivial.
- + unfold jj, kk in H.
- clear H1 H2 jj kk.
- replace (2 * PI * INR k / INR (addn n 1) + 2 * PI * IZR x)%R with
- (2 * PI * (INR k / INR (addn n 1) + IZR x))%R in H by lra.
- replace (2 * PI * INR j / INR (addn n 1))%R with
- (2 * PI * (INR j / INR (addn n 1)))%R in H by lra.
- apply (f_equal (fun r => (inv (2 * PI)) * r))%R in H.
- generalize Pi2_neq0_alt; intros.
- rewrite mulrDr -(mulrA _ (INR k) _) !(mulrA _ (2 * PI) _) (mulrC _ (2 * PI)) divff in H; trivial.
- rewrite !mul1r in H.
- apply (f_equal (fun r => r * (INR (addn n 1)))) in H.
- replace (addn n 1) with (S n) in H by lia.
- generalize (S_INR_n0 n); intros.
- rewrite mulrDl -!mulrA !(mulrC _ (INR (S n))) divff in H; trivial.
- rewrite !mulr1 mulrC !INR_IZR_INZ in H.
- repeat rewrite <- mult_IZR in H.
- repeat rewrite <- plus_IZR in H.
- apply eq_IZR in H.
- apply Nat2Z.inj.
- rewrite !Nat2Z.inj_mod H.
- transitivity (Z.modulo (Z.add (Z.of_nat k) (Z.mul x (Z.of_nat (S n)))) (Z.of_nat (S n))).
- * by f_equal.
- * by rewrite Zdiv.Z_mod_plus_full.
-Qed.
-
-Lemma nth_root_pow_eq (n j k : nat) :
- 0%N <> n ->
- forall (e1 e2 : nat),
- (nth_root j n) ^+ e1 =
- (nth_root k n) ^+ e2 <->
- (e1 * j) mod n = (e2 * k) mod n.
-Proof.
- intros.
- destruct n; try lia.
- by rewrite !exp_nth_root -nth_root_eq.
-Qed.
-
-Lemma nth_root_1_iff n j :
- nth_root j (S n) = 1 <-> j mod (S n) = 0%N.
-Proof.
- rewrite <- (nth_root_0 n).
- rewrite <- nth_root_eq.
- replace (0 mod S n) with 0%N; try easy.
- rewrite Nat.mod_small; lia.
-Qed.
-
-Lemma nth_root_not_1 j n :
- j mod (S n) <> 0%N ->
- nth_root j (S n) <> 1.
-Proof.
- intros ??.
- rewrite nth_root_1_iff in H0.
- by rewrite H0 in H.
-Qed.
-
-Lemma pow_nth_root_prim n :
- exp (nth_root 1 (S n)) (S n) = 1.
-Proof.
- unfold nth_root.
- rewrite de_moivre.
- replace (INR n.+1 * (2 * PI * INR 1 / INR n.+1)%R) with (2 * PI)%R.
- - by rewrite cos_2PI sin_2PI.
- - rewrite [INR 1]INRE mulr1.
- rewrite [INR n.+1 * _]mulrC.
- rewrite -!mulrA mulVf ?mulr1//.
- replace (zero (Field.zmodType R_fieldType)) with (INR 0) by trivial.
- by rewrite !INRE ssrnum.Num.Theory.eqr_nat.
-Qed.
-
-Lemma pow_nth_root j n :
- exp (nth_root j (S n)) (S n) = 1.
-Proof.
- by rewrite prim_nth_root -exprM mulnC exprM pow_nth_root_prim expr1n.
-Qed.
-
-Lemma nth_root_mul j k n :
- mul (nth_root j (S n)) (nth_root k (S n)) = nth_root (j + k) (S n).
-Proof.
- intros.
- rewrite (prim_nth_root k _).
- rewrite (prim_nth_root j _).
- rewrite (prim_nth_root (j + k) _).
- now rewrite <- exprD.
- Qed.
-
-Lemma nth_root_Sn n :
- nth_root (S n) (S n) = 1.
-Proof.
- by rewrite prim_nth_root nth_root_npow.
-Qed.
-
-Lemma Cinv_r (x : R[i]) :
- x <> 0 ->
- x * (inv x) = 1.
-Proof.
- intros.
- rewrite divff //.
- by case eqP.
-Qed.
-
-Lemma Cinv_l (x : R[i]) :
- x <> 0 ->
- (inv x) * x = 1.
-Proof.
- intros.
- rewrite mulrC Cinv_r //.
-Qed.
-
-Lemma exp_sub_r (c : R[i]) (n m : nat):
- (le m n) ->
- c <> 0 ->
- exp c (n - m) = (exp c n) / (exp c m).
-Proof.
- intros.
- destruct H.
- - rewrite subnn expr0 Cinv_r//.
- generalize (@expf_neq0 _ c m).
- case: eqP => //.
- case: eqP => //.
- intuition.
- - rewrite expfB//.
- case: leP => //.
- lia.
-Qed.
-
-Lemma nth_root_diff j k n :
- (le j k) ->
- (nth_root k (S n)) / (nth_root j (S n)) = nth_root (k-j) (S n).
-Proof.
- intros.
- rewrite (prim_nth_root k _).
- rewrite (prim_nth_root j _).
- rewrite (prim_nth_root (k-j) _).
- rewrite exp_sub_r; trivial.
- apply nth_root_not_0.
-Qed.
-
-Lemma nth_root_inv j n :
- inv (nth_root j (S n)) = nth_root (S n - (j mod S n)) (S n).
-Proof.
- generalize (nth_root_diff (j mod S n) (S n) n); intros.
- rewrite <- H.
- - rewrite nth_root_Sn mul1r.
- f_equal.
- apply (nth_root_mod j (j mod S n) n).
- rewrite Nat.mod_mod; try lia.
- - assert (S n <> 0%N) by lia.
- generalize (Nat.mod_upper_bound j (S n) H0); lia.
- Qed.
-
-Lemma nth_root_div j k n :
- (nth_root j (S n)) / (nth_root k (S n)) =
- nth_root (j + (S n - (k mod S n))) (S n).
-Proof.
- rewrite nth_root_inv.
- apply nth_root_mul.
-Qed.
-
-Definition Cmod (x : R[i]) := (* ComplexField.Normc.normc. *)
- let: a +i* b := x in sqrt (exp a 2 + exp b 2).
-
-Lemma nth_root_Cmod j n :
- Cmod (nth_root j (S n)) = 1%R.
-Proof.
- unfold Cmod, nth_root, fst, snd.
- rewrite Rplus_comm /one /= -sqrt_1.
- f_equal.
- by rewrite -!RpowE -!Rsqr_pow2 sin2_cos2.
-Qed.
-
-Lemma Cmod_Cconj_alt (c : R[i]) :
- let: a +i* b :=c in
- c * (conjc c) = (a^+2 + b^+2) +i* 0.
-Proof.
- destruct c.
- unfold mul; simpl.
- f_equal; lra.
-Qed.
-
-Lemma Cmod_Cconj (c : R[i]) :
- c * (conjc c) = RtoC (Rsqr (Cmod c)).
-Proof.
- generalize (Cmod_Cconj_alt c); intros.
- unfold Cmod, fst, snd, RtoC.
- unfold RtoC in H.
- destruct c.
- rewrite H.
- f_equal.
- rewrite -!RpowE Rsqr_sqrt //.
- apply Rplus_le_le_0_compat; apply pow2_ge_0.
-Qed.
-
-Lemma nth_root_conj j n :
- conjc (nth_root j (S n)) = inv (nth_root j (S n)).
-Proof.
- generalize (Cmod_Cconj (nth_root j (S n))); intros.
- rewrite nth_root_Cmod Rsqr_1 in H.
- apply (f_equal (fun c => mul (inv (nth_root j (S n))) c)) in H.
- rewrite /RtoC mulr1 mulrA Cinv_l in H.
- - now rewrite mul1r in H.
- - by apply nth_root_not_0.
-Qed.
-
-Lemma nth_root_conj_alt j n :
- conjc (nth_root j (S n)) = nth_root (S n - j mod (S n)) (S n).
-Proof.
- by rewrite nth_root_conj nth_root_inv.
-Qed.
-
-Lemma nth_root_half_pow_aux n :
- exp (nth_root (S n) (2 * (S n))) 2 = 1.
-Proof.
- replace (muln 2 (S n)) with (S (2 * n + 1)) by lia.
- rewrite exp_nth_root.
- do 2 replace (muln 2 (S n)) with (S (2 * n + 1)) by lia.
- now rewrite nth_root_Sn.
-Qed.
-
-Lemma pow2_inv x y : (x ^+ 2)%R = y -> Rabs x = sqrt y.
-Proof.
- intros eqq1.
- apply (f_equal sqrt) in eqq1.
- rewrite expr2 in eqq1.
- rewrite -eqq1 -sqrt_Rsqr_abs //.
-Qed.
-
-Lemma Rabs_pm_l x y : Rabs x = y -> x = y \/ (- x)%R = y.
-Proof.
- unfold Rabs.
- destruct (Rcase_abs); [right|left]; rewrite -H //.
-Qed.
-
-Lemma Rabs_pm_r x y : Rabs x = y -> x = y \/ x = (- y)%R.
-Proof.
- unfold Rabs.
- destruct (Rcase_abs); [right|left]; rewrite -H //.
- unfold opp; simpl.
- by rewrite Ropp_involutive.
-Qed.
-
-Lemma cmult_real (c : R[i]) :
- Im (c * c) = 0 <->
- Re c = 0 \/ Im c = 0.
-Proof.
- destruct c.
- simpl.
- split; intros.
- - assert (Re * Im = 0) by lra.
- rewrite /mul /zero /= in H0.
- apply Rmult_integral in H0.
- by rewrite /zero /=.
- - destruct H; rewrite H; lra.
-Qed.
-
-Lemma Cpow_2 (c : R[i]) :
- exp c 2 = 1 -> c = 1 \/ c = - 1.
-Proof.
- rewrite expr2; intros.
- assert (Im (c * c) = 0).
- {
- by rewrite H /=.
- }
- rewrite cmult_real in H0.
- destruct c.
- simpl in H0.
- destruct H0.
- - rewrite H0 /mul /= !mul0r mulr0 !add0r in H.
- injection H; intros.
- rewrite /mul /one /opp /= in H1.
- generalize (pow2_ge_0 Im); intros.
- rewrite /pow Rmult_1_r in H2.
- coq_lra.
- - rewrite H0 /mul /= !mul0r !mulr0 oppr0 !addr0 in H.
- rewrite H0.
- injection H; intros.
- generalize (Rsqr_1_iff Re); intros.
- unfold Rsqr in H2.
- rewrite /one /mul /= in H1.
- rewrite H2 in H1.
- destruct H1.
- + left.
- by rewrite H1 /one /=.
- + right.
- by rewrite H1 /one /= /opp /= oppr0 /opp /one /= -IZR_NEG.
-Qed.
-
-Lemma nth_root_half_pow n :
- nth_root (S n) (2 * (S n)) = -1.
-Proof.
- generalize (nth_root_half_pow_aux n); intros.
- apply Cpow_2 in H.
- destruct H; trivial.
- replace (muln 2 (S n)) with (S (2 * n + 1)) in H by lia.
- generalize (nth_root_not_1 (S n) (2*n+1)); intros.
- assert (S n mod S (2 * n + 1) <> 0%N).
- {
- rewrite Nat.mod_small; lia.
- }
- tauto.
-Qed.
-
-Lemma pow2_S (j:nat) :
- exists (k : nat), expn 2 j = S k.
-Proof.
- exists (2^j-1)%nat.
- lia.
-Qed.
-
-Lemma odd_roots_prim j n :
- exp (nth_root (2 * j + 1) (2 ^ (S n))) (2^n) = -1.
-Proof.
- generalize (pow2_S (S n)); intros.
- destruct H.
- rewrite H.
- rewrite exp_nth_root.
- rewrite <- H.
- assert ((2 ^ n * (2 * j + 1) mod (2 ^ S n)) =
- (2 ^ n mod (2 ^ S n)))%N.
- {
- replace (2 ^n * (2 * j + 1))%N with (2 ^ n + j*(2 * 2^n))%N by lia.
- replace (2 ^ (S n))%N with (2 * 2^n)%N.
- - rewrite Nat.mod_add; try lia.
- - by rewrite expnS.
- }
- rewrite H in H0.
- apply nth_root_mod in H0.
- rewrite <- H in H0.
- rewrite H0.
- generalize (pow2_S n); intros.
- destruct H1.
- simpl.
- replace (2 ^ n + (2 ^n + 0))%N with (2 * 2^n)%N by lia.
- rewrite expnS H1.
- apply nth_root_half_pow.
-Qed.
-
-Lemma mult_conj_root j n :
- (nth_root j (S n)) * (conjc (nth_root j (S n))) = 1.
-Proof.
- rewrite nth_root_conj Cinv_r //.
- by apply nth_root_not_0.
-Qed.
-
-Lemma nth_root_half n :
- nth_root (2 ^n) (2 ^ (S n)) = - 1.
-Proof.
- destruct (pow2_S (S n)).
- generalize (odd_roots_prim 0 n); intros.
- rewrite H exp_nth_root -H in H0.
- by rewrite muln0 add0n muln1 in H0.
-Qed.
-
-Lemma nth_root_opp j n :
- (nth_root j (2 ^ (S n))) + (nth_root (j + 2^n) (2 ^ (S n))) = 0.
-Proof.
- destruct (pow2_S (S n)).
- by rewrite H -nth_root_mul -H nth_root_half mulrN1 addrC addNr.
-Qed.
-
-Definition Nat2Zinj := (Nat2Z.inj_mod, Nat2Z.inj_mul, Nat2Z.inj_add, Nat2Z.inj_sub).
-
-Lemma inv_pow_nth_root j k :
- exp (nth_root j (S k)) k = inv (nth_root j (S k)).
-Proof.
- rewrite nth_root_inv exp_nth_root -nth_root_eq.
- apply Nat2Z.inj.
- rewrite !Nat2Zinj.
- - rewrite Zdiv.Zmult_mod Zdiv.Zminus_mod.
- rewrite Zdiv.Z_mod_same_full Zdiv.Zmod_mod.
- rewrite Z.sub_0_l Z.opp_eq_mul_m1 Z.mul_comm.
- symmetry.
- rewrite Zdiv.Zmult_mod !Zdiv.Zmod_mod.
- f_equal.
- f_equal.
- replace (Zneg xH)%Z with (Z.sub (Z.of_nat k) (Z.of_nat (S k))).
- + rewrite Zdiv.Zminus_mod Zdiv.Z_mod_same_full.
- f_equal.
- rewrite Z.mod_small; lia.
- + rewrite Nat2Z.inj_succ; lia.
- - generalize (Nat.mod_upper_bound j (S k)); lia.
-Qed.
-
-Lemma conj_pow_nth_root j k :
- exp (nth_root j (S k)) k = conjc (nth_root j (S k)).
-Proof.
- by rewrite nth_root_conj inv_pow_nth_root.
-Qed.
-
-
-(* testing notations *)
-Definition C0': R[i] := 0.
-Definition C1': R[i] := 1.
-Definition Cplus' (x y : R[i]) := x + y.
-Definition Cmult' (x y : R[i]) := x * y.
-Definition Cexp' (x : R[i]) (n : nat) := x ^+ n.
-Definition Cdiv' (x y : R[i]) := x / y.
-Definition Cinv' (x : R[i]) := x^-1.
-
diff --git a/coq/FHE/polyinterp.v b/coq/FHE/polyinterp.v
deleted file mode 100644
index 620b5ec2..00000000
--- a/coq/FHE/polyinterp.v
+++ /dev/null
@@ -1,3325 +0,0 @@
-Require Import Reals Permutation Morphisms.
-Require Import Coquelicot.Complex.
-Require Import List.
-Require Import Lra Lia.
-Require Import Utils.
-Require Import Vector.
-
-Set Bullet Behavior "Strict Subproofs".
-
-Lemma Forall2_nth_error_iff {A B} (P:A->B->Prop) (l1 : list A) (l2: list B) :
- (forall (i : nat), match nth_error l1 i, nth_error l2 i with
- | Some a, Some b => P a b
- | None, None => True
- | _, _ => False
- end
- ) <->
- Forall2 P l1 l2.
-Proof.
- split.
- - revert l2; induction l1; destruct l2; simpl in *; trivial; intros HH.
- + specialize (HH 0); simpl in HH; contradiction.
- + specialize (HH 0); simpl in HH; contradiction.
- + constructor.
- * now specialize (HH 0); simpl in HH.
- * apply IHl1; intros i.
- now specialize (HH (S i)).
- - induction 1; intros [| i]; simpl; trivial.
- apply IHForall2.
-Qed.
-
-Lemma nth_error_eqs {A} (l1 l2 : list A) :
- (forall (i : nat), nth_error l1 i = nth_error l2 i) ->
- l1 = l2.
-Proof.
- intros HH.
- apply Forall2_eq.
- apply Forall2_nth_error_iff; intros i.
- rewrite HH.
- now destruct (nth_error l2 i).
-Qed.
-
-Lemma nth_error_eqs_len {A} (l1 l2 : list A) :
- length l1 = length l2 ->
- (forall (i : nat), i < length l1 -> nth_error l1 i = nth_error l2 i) ->
- l1 = l2.
-Proof.
- intros eqq1 HH.
- apply nth_error_eqs; intros.
- destruct (lt_dec i (length l1)); auto 2.
- destruct (nth_error_None l2 i) as [_ eqq2].
- rewrite eqq2 by lia.
- apply nth_error_None; lia.
-Qed.
-
-Lemma rev_nth_error {A} (l:list A) (n:nat) :
- n < length l -> nth_error (rev l) n = nth_error l (length l - S n).
-Proof.
- revert n.
- induction l using rev_ind; simpl; [lia |]; intros n nlt.
- rewrite rev_app_distr, app_length; simpl.
- destruct n.
- - simpl.
- rewrite nth_error_app2 by lia.
- now replace (length l + 1 - 1 - length l) with 0 by lia.
- - simpl.
- rewrite app_length in nlt; simpl in nlt.
- rewrite IHl by lia.
- rewrite nth_error_app1 by lia.
- f_equal.
- lia.
-Qed.
-
-Lemma seq_nth_error [len : nat] (start : nat) [n : nat] :
- n < len -> nth_error (seq start len) n = Some (start + n).
-Proof.
- intros nlt.
- rewrite (nth_error_nth' _ 0).
- - now rewrite seq_nth.
- - now rewrite seq_length.
-Qed.
-
-Lemma rev_seq start n :
- rev (seq start n) = map (fun i => (n + 2 * start - S i)) (seq start n).
-Proof.
- apply nth_error_eqs_len.
- - now rewrite rev_length, map_length.
- - intros i ilt.
- rewrite rev_length in ilt.
- rewrite rev_nth_error by trivial.
- rewrite seq_length in ilt.
- rewrite nth_error_map.
- repeat rewrite seq_nth_error; trivial
- ; rewrite seq_length.
- + simpl; f_equal; lia.
- + lia.
-Qed.
-
-Lemma map_skipn_S_error {A:Type} (l:list A) n a :
- nth_error l n = Some a ->
- skipn n l = a :: skipn (S n) l.
-Proof.
- revert l.
- induction n; simpl; destruct l; simpl; intros; try congruence.
- rewrite IHn; trivial.
-Qed.
-
-Lemma list_cons_app_hyp_even {A} (P:list A->Prop) (R:A->A->Prop)
- (Pnil : P nil)
- (Psmoosh : forall a z l, R a z -> R z a -> P l -> P (a::(l ++ z ::nil)))
- : forall n (l:list A), length l = 2 * n -> Forall2 R l (rev l) -> P l.
-Proof.
- induction n; simpl.
- - intros [|]; simpl; congruence.
- - intros.
- destruct l; simpl in *; [lia |].
- destruct l using rev_ind; simpl in *; [lia |].
- rewrite app_length in H; simpl in H.
- assert (eqq1:length l = 2 * n) by lia.
- specialize (IHn _ eqq1).
- rewrite rev_app_distr in H0.
- simpl in H0.
- invcs H0.
- apply Forall2_app_tail_inv in H6.
- destruct H6.
- auto.
-Qed.
-
-Lemma list_cons_app_hyp_odd {A} (P:list A->Prop) (R:A->A->Prop)
- (Psingle : forall a, R a a -> P (a :: nil))
- (Psmoosh : forall a z l, R a z -> R z a -> P l -> P (a::(l ++ z ::nil)))
- : forall n (l:list A), length l = 2 * n + 1 -> Forall2 R l (rev l) -> P l.
-Proof.
- induction n; simpl.
- - intros [|]; simpl; try lia.
- destruct l; simpl; try lia; intros.
- invcs H0.
- auto.
- - intros.
- destruct l; simpl in *; [lia |].
- destruct l using rev_ind; simpl in *; [lia |].
- rewrite app_length in H; simpl in H.
- assert (eqq1:length l = 2 * n + 1) by lia.
- specialize (IHn _ eqq1).
- rewrite rev_app_distr in H0.
- simpl in H0.
- invcs H0.
- apply Forall2_app_tail_inv in H6.
- destruct H6.
- auto.
-Qed.
-
-Lemma list_cons_app_hyp {A} (P:list A->Prop) (R:A->A->Prop)
- (Pnil : P nil)
- (Psingle : forall a, R a a -> P (a :: nil))
- (Psmoosh : forall a z l, R a z -> R z a -> P l -> P (a::(l ++ z ::nil)))
- : forall (l:list A), Forall2 R l (rev l) -> P l.
-Proof.
- intros.
- destruct (NPeano.Nat.Even_Odd_dec (length l)).
- - destruct e.
- eapply list_cons_app_hyp_even; eauto.
- - destruct o.
- eapply list_cons_app_hyp_odd; eauto.
-Qed.
-
-Lemma nth_error_firstn_in {A} (l:list A) n i :
- i < n ->
- nth_error (firstn n l) i = nth_error l i.
-Proof.
- revert n l.
- induction i; simpl; intros
- ; destruct n; destruct l; simpl; trivial; try lia.
- apply IHi; lia.
-Qed.
-
-Lemma nth_error_firstn_nin {A} (l:list A) n i :
- i >= n ->
- nth_error (firstn n l) i = None.
-Proof.
- intros.
- apply nth_error_None.
- rewrite firstn_length.
- lia.
-Qed.
-
-Lemma nth_error_skipn {A} (l:list A) n i :
- nth_error (skipn n l) i = nth_error l (n + i).
-Proof.
- revert l i.
- induction n; simpl; trivial; intros.
- destruct l; trivial.
- now rewrite nth_error_nil.
-Qed.
-
-Lemma firstn_seq n1 n2 start :
- n1 <= n2 ->
- firstn n1 (seq start n2) = seq start n1.
-Proof.
- intros.
- apply nth_error_eqs_len.
- - rewrite firstn_length; repeat rewrite seq_length.
- lia.
- - intros i ilt.
- rewrite firstn_length, seq_length in ilt.
- rewrite nth_error_firstn_in by lia.
- now repeat rewrite seq_nth_error by lia.
-Qed.
-
-Lemma skipn_seq n1 n2 start :
- skipn n1 (seq start n2) = seq (start+n1) (n2-n1).
-Proof.
- intros.
- apply nth_error_eqs_len.
- - now rewrite skipn_length; repeat rewrite seq_length.
- - intros i ilt.
- rewrite skipn_length in ilt; repeat rewrite seq_length in ilt.
- rewrite nth_error_skipn.
- repeat rewrite seq_nth_error by lia.
- f_equal; lia.
-Qed.
-
-Lemma combine_nth_error [A B : Type] (l : list A) (l' : list B) (n : nat) :
- length l = length l' -> nth_error (combine l l') n = match nth_error l n, nth_error l' n with
- | Some x, Some y => Some (x,y)
- | _, _ => None
- end.
-Proof.
- revert l l'.
- induction n; destruct l; destruct l'; simpl; try congruence.
- intros.
- apply IHn; congruence.
-Qed.
-
-
-(* represent complex number as pair *)
-Definition nth_root (j n : nat) : C :=
- let c := (2*PI*INR(j)/INR(n))%R in
- (cos c, sin c).
-
-Lemma S_INR_not_0 n :
- INR (S n) <> 0%R.
-Proof.
- rewrite S_INR.
- generalize (pos_INR n).
- lra.
-Qed.
-
-
-Lemma nth_root_0 n :
- nth_root 0 (S n) = R1.
-Proof.
- unfold nth_root.
- assert ((2 * PI * INR 0 / INR (S n))%R = 0%R).
- {
- unfold INR at 1.
- field.
- apply S_INR_not_0.
- }
- rewrite H.
- now rewrite cos_0, sin_0.
-Qed.
-
-Lemma nth_root_2PI n j :
- nth_root (j * (S n)) (S n) = R1.
-Proof.
- unfold nth_root.
- rewrite mult_INR.
- replace (2 * PI * (INR j * INR (S n)) / INR (S n))%R with
- (0 + 2 * INR(j) * PI)%R.
- - rewrite cos_period, sin_period.
- now rewrite cos_0, sin_0.
- - field.
- apply S_INR_not_0.
-Qed.
-
-
-Lemma nth_root_2PI_plus n j k :
- nth_root (j + k * (S n)) (S n) = nth_root j (S n).
-Proof.
- unfold nth_root.
- replace (2 * PI * INR (j + k * S n) / INR (S n))%R with
- (2 * PI * INR(j)/INR(S n) + 2 * INR k * PI)%R.
- - now rewrite cos_period, sin_period.
- - rewrite plus_INR, mult_INR.
- field.
- apply S_INR_not_0.
- Qed.
-
-Definition nth_roots (n:nat) :=
- map (fun j => nth_root j n) (seq 0 n).
-
-Lemma de_moive (x : R) (n : nat) :
- Cpow (cos x, sin x) n = (cos ((INR n) * x), sin ((INR n) * x)).
-Proof.
- induction n.
- - simpl.
- rewrite Rmult_0_l.
- now rewrite cos_0, sin_0.
- - simpl Cpow.
- rewrite IHn.
- unfold Cmult, fst, snd.
- replace (INR (S n) * x)%R with (x + (INR n) * x)%R.
- + rewrite cos_plus, sin_plus.
- f_equal.
- lra.
- + rewrite S_INR.
- lra.
- Qed.
-
-Lemma Cpow_nth_root j n e :
- Cpow (nth_root j (S n)) e = nth_root (e * j) (S n).
-Proof.
- unfold nth_root.
- rewrite de_moive.
- rewrite mult_INR.
- do 2 f_equal; field; apply S_INR_not_0.
-Qed.
-
-Lemma Cpow_nth_root_comm j n e :
- Cpow (nth_root j (S n)) e = Cpow (nth_root e (S n)) j.
-Proof.
- do 2 rewrite Cpow_nth_root.
- f_equal.
- lia.
-Qed.
-
-Lemma nth_root_npow j n :
- Cpow (nth_root j (S n)) (S n) = RtoC R1.
-Proof.
- rewrite Cpow_nth_root.
- replace (S n * j) with (j * S n) by lia.
- now rewrite nth_root_2PI.
-Qed.
-
-Lemma minus_mod (j1 j2 n : nat) :
- j1 mod (S n) = j2 mod (S n) ->
- (j2 - j1) mod (S n) = 0.
-Proof.
- intros eqq1.
- destruct (le_dec j1 j2).
- - generalize (Zdiv.Zminus_mod (Z.of_nat j2) (Z.of_nat j1) (Z.of_nat (S n)))
- ; intros HH.
- rewrite <- Nat2Z.inj_sub in HH by trivial.
- repeat rewrite <- Nat2Z.inj_mod in HH.
- rewrite <- eqq1 in HH.
- rewrite Z.sub_diag in HH.
- rewrite Zdiv.Zmod_0_l in HH.
- apply (f_equal Z.to_nat) in HH.
- now rewrite Nat2Z.id in HH.
- - rewrite Minus.not_le_minus_0_stt by trivial.
- now apply Nat.mod_0_l.
-Qed.
-
-Lemma nth_root_mod j1 j2 n :
- j1 mod (S n) = j2 mod (S n) ->
- nth_root j1 (S n) = nth_root j2 (S n).
-Proof.
- intros.
- destruct (le_dec j1 j2).
- - assert (exists (k:nat), j2 = j1 + k * (S n)).
- {
- generalize (Nat.div_mod_eq (j2 - j1) (S n)); intros.
- exists ((j2 - j1)/(S n)).
- rewrite minus_mod in H0; trivial; lia.
- }
- destruct H0.
- rewrite H0.
- now rewrite nth_root_2PI_plus.
- - assert (exists (k:nat), j1 = j2 + k * (S n)).
- {
- generalize (Nat.div_mod_eq (j1 - j2) (S n)); intros.
- exists ((j1 - j2)/(S n)).
- rewrite minus_mod in H0; trivial; lia.
- }
- destruct H0.
- rewrite H0.
- now rewrite nth_root_2PI_plus.
- Qed.
-
-Fixpoint list_Cplus (l : list C) : C :=
- match l with
- | nil => R0
- | a :: l' => Cplus a (list_Cplus l')
- end.
-
-Lemma list_Cplus_mult_l c l :
- (list_Cplus (map (fun a => c * a) l) = c * list_Cplus l)%C.
-Proof.
- induction l; simpl.
- - ring.
- - rewrite IHl.
- ring.
-Qed.
-
-Lemma list_cplus_mult_seq_comm c (p : list R) :
-list_Cplus
- (map (fun '(a0, b) => (RtoC a0 * b)%C)
- (combine p (map (fun j : nat => (c ^ j)%C) (seq 1 (length p))))) =
- (c *
- list_Cplus
- (map (fun '(a0, b) => RtoC a0 * b)
- (combine p (map (fun j : nat => c ^ j) (seq 0 (length p))))))%C.
-Proof.
- rewrite <- list_Cplus_mult_l.
- rewrite <- seq_shift.
- rewrite map_map.
- rewrite <- (map_id p) at 1 3.
- repeat rewrite combine_map.
- repeat rewrite map_map.
- f_equal.
- apply map_ext; intros [??]; simpl.
- ring.
-Qed.
-
-Lemma list_Cplus_Re (l : list C) :
- Re (list_Cplus l) = list_sum (map Re l).
-Proof.
- induction l.
- - now simpl.
- - simpl.
- now rewrite <- IHl.
-Qed.
-
-Lemma prim_nth_root j n :
- nth_root j (S n) = Cpow (nth_root 1 (S n)) j.
-Proof.
- rewrite Cpow_nth_root.
- f_equal.
- lia.
- Qed.
-
-Lemma nth_root_not_0 j n :
- nth_root j (S n) <> R0.
-Proof.
- unfold nth_root.
- unfold RtoC.
- generalize cos_sin_0; intros.
- specialize (H (2 * PI * INR j / INR (S n))%R).
- replace R0 with 0%R by lra.
- unfold not.
- intros.
- apply H.
- split.
- - apply (f_equal (fun c => fst c)) in H0.
- now unfold fst in H0.
- - apply (f_equal (fun c => snd c)) in H0.
- now unfold snd in H0.
- Qed.
-
-
-Lemma cos1_sin0 (x : R) :
- cos x = R1 ->
- sin x = R0.
-Proof.
- intros eqq1.
- generalize (cos2 x).
- rewrite eqq1; intros eqq2.
- replace R1 with 1%R in eqq2 by trivial.
- rewrite Rsqr_1 in eqq2.
- apply Rsqr_0_uniq.
- lra.
-Qed.
-
-Lemma cosneg1_sin0 (x : R) :
- cos x = Ropp R1 ->
- sin x = R0.
-Proof.
- intros eqq1.
- generalize (cos2 x).
- rewrite eqq1; intros eqq2.
- replace R1 with 1%R in eqq2 by trivial.
- rewrite <- Rsqr_neg in eqq2.
- rewrite Rsqr_1 in eqq2.
- apply Rsqr_0_uniq.
- lra.
-Qed.
-
-Lemma cos_eq_1_aux_pos (x : R) :
- cos x = R1 ->
- exists k, x = (PI * IZR(k))%R.
-Proof.
- intros eqq1.
- generalize (cos1_sin0 _ eqq1); intros eqq2.
- apply sin_eq_0_0 in eqq2.
- destruct eqq2 as [k eqqk].
- exists k.
- lra.
-Qed.
-
-Lemma cos_eq_1_aux_neg (x : R) :
- cos x = Ropp R1 ->
- exists k, x = (PI * IZR(k))%R.
-Proof.
- intros eqq1.
- generalize (cosneg1_sin0 _ eqq1); intros eqq2.
- apply sin_eq_0_0 in eqq2.
- destruct eqq2 as [k eqqk].
- exists k.
- lra.
-Qed.
-
-Lemma cos_eq_1 (x : R) :
- cos x = R1 ->
- exists k, x = (2 * PI * IZR(k))%R.
-Proof.
- intros eqq1.
- destruct (cos_eq_1_aux_pos _ eqq1) as [kk eqq2]; subst.
- assert (cutter:(forall kk, ((0 <= kk)%Z -> cos (PI * IZR kk) = R1 -> exists k : Z, (PI * IZR kk)%R = (2 * PI * IZR k)%R)) -> (forall kk, (cos (PI * IZR kk) = R1 -> (exists k : Z, (PI * IZR kk)%R = (2 * PI * IZR k)%R
- )))).
- {
- clear.
- intros HH kk eqq1.
- destruct (Z_le_gt_dec 0 kk); [eauto |].
- destruct (HH (Z.opp kk)%Z).
- - lia.
- - rewrite opp_IZR.
- replace (PI * - IZR kk)%R with (- (PI * IZR kk))%R by lra.
- now rewrite cos_neg.
- - exists (Z.opp x).
- rewrite opp_IZR in H |- *.
- lra.
- }
-
- apply cutter; trivial; clear.
- intros kk kk_nneg eqq1.
- destruct (Zeven_odd_dec kk).
- - destruct (Zeven_ex _ z); subst.
- exists x.
- rewrite mult_IZR.
- lra.
- - destruct (Zodd_ex _ z); subst.
- rewrite plus_IZR, mult_IZR in eqq1.
- replace ((PI * (2 * IZR x + 1))%R) with
- (2 * IZR x * PI + PI)%R in eqq1 by lra.
- rewrite neg_cos in eqq1.
- assert (eqq2: cos (2 * IZR x * PI)%R = Ropp R1) by lra.
- generalize (cos_period 0 (Z.to_nat x)); intros HH.
- rewrite cos_0 in HH.
- rewrite INR_IZR_INZ in HH.
- rewrite Z2Nat.id in HH by lia.
- replace (2 * IZR x * PI)%R with (0 + 2 * IZR x * PI)%R in eqq2 by lra.
- lra.
-Qed.
-
-Lemma cos_eq_neg1 (x : R) :
- cos x = Ropp R1 ->
- exists k, x = (2 * PI * IZR(k) + PI)%R.
-Proof.
- intros eqq1.
- generalize (Rtrigo_facts.cos_pi_plus x); intros eqq2.
- rewrite eqq1 in eqq2.
- rewrite Ropp_involutive in eqq2.
- apply cos_eq_1 in eqq2.
- destruct eqq2 as [k eqq2].
- exists (k-1)%Z.
- rewrite minus_IZR.
- lra.
-Qed.
-
-Lemma cos_eq_1_1 : forall x:R, (exists k : Z, x = (IZR k * 2 * PI)%R) -> cos x = 1%R.
-Proof.
- assert (forall n, cos (INR n * 2 * PI) = 1%R). {
- intros n;induction n as [|n IHn].
- { change (INR 0) with 0%R.
- replace (0 * 2 * PI)%R with 0%R by ring.
- exact cos_0. }
- rewrite S_INR.
- replace ((INR n + 1) * 2 * PI)%R with ((INR n) * 2 * PI + 2 * PI)%R by ring.
- rewrite cos_plus, IHn, cos_2PI, sin_2PI.
- ring.
- }
- intros x [k Hx].
- rewrite Hx;clear x Hx.
- destruct (Z.abs_or_opp_abs k).
- - replace (IZR k) with (INR (Z.to_nat k)).
- { apply H. }
- rewrite INR_IZR_INZ.
- f_equal.
- apply Z2Nat.id.
- lia.
- - replace (IZR k) with (- INR (Z.to_nat (- k)))%R.
- { replace (- INR (Z.to_nat (- k)) * 2 * PI)%R with (- (INR (Z.to_nat (- k)) * 2 * PI))%R by ring.
- rewrite cos_neg.
- rewrite H;ring. }
- rewrite INR_IZR_INZ.
- rewrite <-opp_IZR. f_equal.
- lia.
-Qed.
-
-Lemma cos_lt_1 (x : R) :
- (0 < x)%R ->
- (x < 2*PI)%R ->
- (cos x < 1)%R.
-Proof.
- intros.
- generalize (COS_bound x); intros.
- generalize PI_RGT_0; intro pi_gt.
- destruct H1.
- assert (cos x <> 1)%R.
- {
- unfold not.
- intros.
- generalize (cos_eq_1_aux_pos x H3); intros.
- destruct H4.
- rewrite H4 in H0.
- rewrite Rmult_comm in H0.
- apply Rmult_lt_reg_r in H0; trivial.
- rewrite H4 in H.
- replace 0%R with (PI * 0)%R in H by lra.
- apply Rmult_lt_reg_l in H; trivial.
- assert (x0 = 1)%Z.
- {
- apply lt_IZR in H.
- apply lt_IZR in H0.
- lia.
- }
- rewrite H5 in H4.
- rewrite Rmult_1_r in H4.
- rewrite H4 in H3.
- generalize cos_PI; intros.
- lra.
- }
- lra.
- Qed.
-
-Lemma cos_eq_1_alt (x : R) :
- cos x = R1 ->
- exists (k:Z), x = (2 * PI * IZR(k))%R.
-Proof.
- intros Hx.
- assert (PI2_neq0: (2 * PI <> 0)%R).
- {
- generalize PI_neq0.
- lra.
- }
- destruct (euclidian_division x (2*PI) PI2_neq0) as (q & r & EQ & Hr & Hr').
- exists q.
- rewrite <- (Rplus_0_r (_*_)). subst.
- rewrite Rmult_comm.
- apply Rplus_eq_compat_l.
- rewrite cos_plus in Hx.
- assert (H : cos (IZR q * 2 * PI)%R = 1%R) by ( apply cos_eq_1_1; now exists q).
- rewrite <- Rmult_assoc in Hx.
- rewrite H, Rmult_1_l in Hx.
- rewrite sin_eq_0_1 in Hx.
- - rewrite Rmult_0_l, Rminus_0_r in Hx.
- rewrite Rabs_right in Hr'.
- + destruct Hr as [Hr | ->]; trivial.
- exfalso.
- generalize (cos_lt_1 r Hr Hr'); intros.
- lra.
- + generalize PI_RGT_0; lra.
- - exists (2*q)%Z.
- rewrite mult_IZR.
- lra.
- Qed.
-
-Lemma cos_eq_1_nneg (x : R) :
- cos x = R1 ->
- (0 <= x)%R ->
- exists (k:nat), x = (2 * PI * INR(k))%R.
-Proof.
- intros.
- generalize (cos_eq_1 x H); intros.
- destruct H1.
- rewrite H1 in H0.
- replace (0%R) with (2 * PI * 0)%R in H0 by lra.
- apply Rmult_le_reg_l in H0.
- - apply le_IZR in H0.
- exists (Z.abs_nat x0).
- rewrite H1.
- do 2 f_equal.
- destruct x0; simpl; trivial; try lia.
- now rewrite INR_IPR.
- - generalize PI_RGT_0; lra.
-Qed.
-
-Lemma sin_cos_eq x y:
- sin x = sin y /\ cos x = cos y ->
- exists (k:Z),
- x = (y + 2 * PI * IZR(k))%R.
-Proof.
- intros.
- generalize (cos_minus x y); intros.
- destruct H.
- rewrite H, H1 in H0.
- generalize (sin2_cos2 y); intros.
- rewrite Rplus_comm in H0.
- unfold Rsqr in H2.
- rewrite H2 in H0.
- apply cos_eq_1 in H0.
- destruct H0.
- exists x0.
- rewrite <- H0.
- lra.
-Qed.
-
-Lemma nth_root_eq j k n :
- j mod (S n) = k mod (S n) <->
- nth_root j (S n) = nth_root k (S n).
-Proof.
- split; intros.
- - now apply nth_root_mod.
- - unfold nth_root in H.
- replace (S n) with (n + 1) in H by lia.
- inversion H; clear H.
- generalize (sin_cos_eq (2 * PI * INR j / INR (n + 1))
- (2 * PI * INR k / INR (n + 1))); intros.
- destruct H.
- + split; trivial.
- + replace (2 * PI * INR k / INR (n + 1) + 2 * PI * IZR x)%R with
- (2 * PI * (INR k / INR (n+1) + IZR x))%R in H by lra.
- replace (2 * PI * INR j / INR (n + 1))%R with
- (2 * PI * (INR j / INR (n + 1)))%R in H by lra.
- apply (f_equal (fun r => (/ (2 * PI)) * r))%R in H.
- assert (2 * PI <> 0)%R.
- {
- generalize PI_neq0.
- lra.
- }
- rewrite <- Rmult_assoc in H.
- rewrite <- Rinv_l_sym, Rmult_1_l in H; trivial.
- rewrite <- Rmult_assoc in H.
- rewrite <- Rinv_l_sym, Rmult_1_l in H; trivial.
- clear H0 H1 H2.
- repeat rewrite plus_INR in H.
- simpl in H.
- assert (possn:(INR n + 1)%R <> 0%R).
- {
- generalize (pos_INR n); lra.
- }
- field_simplify in H; try lra.
- apply (f_equal (Rmult (INR n + 1))) in H.
- field_simplify in H; try lra.
- repeat rewrite INR_IZR_INZ in H.
- repeat rewrite <- mult_IZR in H.
- repeat rewrite <- plus_IZR in H.
- apply eq_IZR in H.
- apply Nat2Z.inj.
- repeat rewrite Nat2Z.inj_mod.
- rewrite H.
- transitivity ((Z.of_nat k + (x * (Z.of_nat (S n)))) mod Z.of_nat (S n))%Z.
- * f_equal.
- rewrite Nat2Z.inj_succ.
- lia.
- * now rewrite Zdiv.Z_mod_plus_full.
-Qed.
-
-Lemma nth_root_not_1 j n :
- j mod (S n) <> 0 ->
- nth_root j (S n) <> R1.
-Proof.
- unfold nth_root.
- intros.
- unfold RtoC.
- unfold not.
- intros.
- replace (S n) with (n + 1) in H0 by lia.
- inversion H0; clear H0.
- assert (xnneg :(0 <= 2 * PI * INR j / INR (n + 1))%R).
- {
- apply Rmult_le_pos.
- - generalize (pos_INR j); intros.
- apply Rmult_le_pos; trivial.
- generalize PI_RGT_0; intros.
- lra.
- - left.
- apply Rinv_0_lt_compat.
- apply lt_0_INR.
- lia.
- }
- apply cos_eq_1_nneg in H2; trivial.
- destruct H2.
- apply (f_equal (fun r => (r /(2 * PI))%R)) in H0.
- unfold Rdiv in H0.
- rewrite Rmult_comm in H0.
- assert ((2 * PI)%R <> R0).
- {
- generalize PI_neq0; intros.
- lra.
- }
- do 2 rewrite <- Rmult_assoc in H0.
- rewrite <- Rinv_l_sym in H0; trivial.
- rewrite Rmult_1_l in H0.
- symmetry in H0.
- rewrite Rmult_comm in H0.
- rewrite <- Rmult_assoc in H0.
- rewrite <- Rinv_l_sym in H0; trivial.
- rewrite Rmult_1_l in H0.
- replace (n+1) with (S n) in H0 by lia.
- apply (f_equal (fun r => (r * INR (S n))%R)) in H0.
- rewrite Rmult_assoc in H0.
- rewrite <- Rinv_l_sym in H0.
- - rewrite Rmult_1_r in H0.
- rewrite <- mult_INR in H0.
- apply INR_eq in H0.
- apply (f_equal (fun k => k mod (S n))) in H0.
- rewrite Nat.mod_mul in H0; try lia.
- - apply not_0_INR.
- lia.
- Qed.
-
-Lemma nth_root_1 j n :
- j mod (S n) = 0 ->
- nth_root j (S n) = R1.
-Proof.
- intros.
- rewrite (nth_root_mod j 0 n).
- - now rewrite nth_root_0.
- - rewrite H.
- rewrite Nat.mod_small; lia.
-Qed.
-
-Lemma Cinv_1_r :
- Cinv 1%R = 1%R.
-Proof.
- unfold Cinv.
- unfold RtoC.
- simpl.
- f_equal; field.
-Qed.
-
-Lemma Cpow_sub_r (c : C) (n m : nat):
- m <= n ->
- c <> R0 ->
- (c ^ (n - m))%C = (c ^ n / c ^ m)%C.
-Proof.
- intros.
- assert (Cmult (Cpow c (n - m)) (Cpow c m) = Cpow c n).
- {
- rewrite <- Cpow_add_r.
- f_equal.
- lia.
- }
- rewrite <- H1.
- unfold Cdiv.
- rewrite <- Cmult_assoc.
- rewrite Cinv_r.
- - now rewrite Cmult_1_r.
- - now apply Cpow_nz.
- Qed.
-
-Lemma nth_root_diff j k n :
- j <= k ->
- Cdiv (nth_root k (S n)) (nth_root j (S n)) = nth_root (k-j) (S n).
-Proof.
- intros.
- rewrite (prim_nth_root k _).
- rewrite (prim_nth_root j _).
- rewrite (prim_nth_root (k-j) _).
- rewrite Cpow_sub_r; trivial.
- apply nth_root_not_0.
-Qed.
-
-
-Instance list_Cplus_perm_proper : Proper (@Permutation _ ==> eq) list_Cplus.
-Proof.
- repeat red.
- apply Permutation_ind_bis; trivial; simpl; intros.
- - now rewrite H0.
- - rewrite H0; ring.
- - now rewrite H0, H2.
-Qed.
-
-Lemma C_telescope_mult (c : C) (n : nat) :
- (Cmult (c - R1) (list_Cplus (map (fun j => Cpow c j) (seq 0 (S n)))) =
- (Cpow c (S n) - 1%R))%C.
-Proof.
- induction n.
- - simpl; ring.
- - rewrite seq_S.
- simpl plus.
- rewrite map_app.
- unfold map at 2.
- rewrite <- Permutation_cons_append.
- unfold list_Cplus; fold list_Cplus.
- transitivity ((c - R1) * c ^ S n + (c - R1) * list_Cplus (map (fun j : nat => c ^ j) (seq 0 (S n))))%C; [ring |].
- rewrite IHn.
- simpl; ring.
-Qed.
-
-Lemma C_telescope_div (c : C) (n : nat) :
- c <> R1 ->
- list_Cplus (map (fun j => Cpow c j) (seq 0 (S n))) =
- Cdiv (Cpow c (S n) - 1%R) (c - R1).
-Proof.
- intros.
- generalize (C_telescope_mult c n); intros.
- rewrite <- H0.
- unfold Cdiv.
- rewrite Cmult_comm.
- rewrite Cmult_assoc.
- rewrite Cinv_l.
- - now rewrite Cmult_1_l.
- - unfold not.
- intros.
- unfold not in H.
- apply H.
- apply (f_equal (fun cc => Cplus cc (RtoC R1))) in H1.
- now ring_simplify in H1.
-Qed.
-
-Lemma C_telescope_pow_0 (c : C) (n : nat) :
- c <> R1 ->
- Cpow c (S n) = 1%R ->
- list_Cplus (map (fun j => Cpow c j) (seq 0 (S n))) = 0%R.
-Proof.
- intros.
- rewrite C_telescope_div; trivial.
- rewrite H0.
- field.
- simpl.
- unfold not; intros.
- apply (f_equal (fun cc => (cc + 1%R)%C)) in H1.
- now ring_simplify in H1.
-Qed.
-
-Lemma sum_nth_roots_0 n :
- list_Cplus (map (fun j => Cpow (nth_root 1 (S (S n))) j) (seq 0 (S (S n)))) = R0.
-Proof.
- apply C_telescope_pow_0.
- - apply nth_root_not_1.
- rewrite Nat.mod_1_l; lia.
- - now rewrite nth_root_npow.
- Qed.
-
-Lemma sum_nth_roots_0_gen k n :
- k mod (S (S n)) <> 0 ->
- list_Cplus (map (fun j => Cpow (nth_root k (S (S n))) j) (seq 0 (S (S n)))) = R0.
-Proof.
- intros.
- rewrite C_telescope_div.
- - rewrite nth_root_npow.
- unfold Cminus.
- rewrite Cplus_opp_r.
- unfold Cdiv.
- now rewrite Cmult_0_l.
- - now apply nth_root_not_1.
-Qed.
-
-Lemma pow_nth_root_prim n :
- Cpow (nth_root 1 (S n)) (S n) = R1.
-Proof.
- unfold nth_root.
- rewrite de_moive.
- replace (INR (S n) * (2 * PI * INR 1 / INR (S n)))%R with (2 * PI)%R.
- - now rewrite cos_2PI, sin_2PI.
- - replace (INR 1) with R1 by now unfold INR.
- field.
- apply S_INR_not_0.
- Qed.
-
-Lemma list_Cplus_app l1 l2 :
- list_Cplus (l1 ++ l2) = Cplus (list_Cplus l1) (list_Cplus l2).
-Proof.
- induction l1.
- - simpl.
- now rewrite Cplus_0_l.
- - simpl.
- rewrite IHl1.
- now rewrite Cplus_assoc.
- Qed.
-
-Lemma root_prod_1 j n :
- list_Cplus
- (map (fun k => Cmult (Cpow (nth_root j (S n)) k) (Cpow (Cinv (nth_root k (S n))) j))
- (seq 0 (S n))) = INR (S n).
-Proof.
- replace (map (fun k => Cmult (Cpow (nth_root j (S n)) k) (Cpow (Cinv (nth_root k (S n))) j))
- (seq 0 (S n))) with
- (map (fun k => RtoC R1) (seq 0 (S n))).
- - induction n.
- + simpl.
- now rewrite Cplus_0_r.
- + rewrite seq_S.
- rewrite map_app.
- rewrite list_Cplus_app.
- rewrite IHn.
- replace (S n) with (n + 1) by lia.
- replace (S (n + 1)) with (n + 2) by lia.
- simpl.
- do 2 rewrite plus_INR.
- simpl.
- unfold RtoC.
- rewrite Cplus_0_r.
- unfold Cplus, fst, snd.
- f_equal; lra.
- - apply map_ext.
- intros.
- rewrite Cpow_inv.
- + do 2 rewrite Cpow_nth_root.
- replace (j * a) with (a * j) by lia.
- rewrite Cinv_r.
- * now replace R1 with 1%R by lra.
- * apply nth_root_not_0.
- + apply nth_root_not_0.
- Qed.
-
-Lemma pow_nth_root j n :
- Cpow (nth_root j (S n)) (S n) = R1.
-Proof.
- rewrite prim_nth_root.
- rewrite <- Cpow_mult_r.
- replace (j * S n) with (S n * j) by lia.
- rewrite Cpow_mult_r.
- rewrite pow_nth_root_prim.
- now rewrite Cpow_1_l.
-Qed.
-
-Lemma nth_root_mul j k n :
- Cmult (nth_root j (S n)) (nth_root k (S n)) = nth_root (j + k) (S n).
-Proof.
- intros.
- rewrite (prim_nth_root k _).
- rewrite (prim_nth_root j _).
- rewrite (prim_nth_root (j + k) _).
- rewrite Cpow_add_r; trivial.
- Qed.
-
-Lemma nth_root_Sn n :
- nth_root (S n) (S n) = 1%R.
-Proof.
- rewrite prim_nth_root.
- now rewrite nth_root_npow.
-Qed.
-
-Lemma nth_root_inv j n :
- Cinv (nth_root j (S n)) = nth_root (S n - (j mod S n)) (S n).
-Proof.
- generalize (nth_root_diff (j mod S n) (S n) n); intros.
- rewrite <- H.
- - rewrite nth_root_Sn.
- unfold Cdiv.
- rewrite Cmult_1_l.
- f_equal.
- apply (nth_root_mod j (j mod S n) n).
- rewrite Nat.mod_mod; try lia.
- - assert (S n <> 0) by lia.
- generalize (Nat.mod_upper_bound j (S n) H0); intros.
- lia.
- Qed.
-
-Lemma nth_root_div j k n :
- Cdiv (nth_root j (S n)) (nth_root k (S n)) =
- nth_root (j + (S n - (k mod S n))) (S n).
-Proof.
- unfold Cdiv.
- rewrite nth_root_inv.
- apply nth_root_mul.
-Qed.
-
-Lemma nth_root_Cmod j n :
- Cmod (nth_root j (S n)) = 1%R.
-Proof.
- unfold Cmod, nth_root, fst, snd.
- rewrite Rplus_comm.
- rewrite <- sqrt_1.
- f_equal.
- do 2 rewrite <- Rsqr_pow2.
- now rewrite sin2_cos2.
-Qed.
-
-Lemma Cmod_Cconj (c : C) :
- Cmult c (Cconj c) = Rsqr (Cmod c).
-Proof.
- destruct c.
- unfold Cconj, Cmod, Cmult, fst, snd.
- rewrite Rsqr_sqrt.
- - unfold RtoC.
- f_equal; lra.
- - apply Rplus_le_le_0_compat; apply pow2_ge_0.
-Qed.
-
-Lemma nth_root_conj j n :
- Cconj (nth_root j (S n)) = Cinv (nth_root j (S n)).
-Proof.
- generalize (Cmod_Cconj (nth_root j (S n))); intros.
- rewrite nth_root_Cmod in H.
- rewrite Rsqr_1 in H.
- apply (f_equal (fun c => Cmult (/ nth_root j (S n)) c)) in H.
- rewrite Cmult_1_r in H.
- rewrite Cmult_assoc in H.
- rewrite Cinv_l in H.
- - now rewrite Cmult_1_l in H.
- - apply nth_root_not_0.
-Qed.
-
-Definition Ceval_poly (p : list C) (c : C) :=
- let cpows := map (fun j => Cpow c j) (seq 0 (length p)) in
- list_Cplus (map (fun '(a, b) => Cmult a b) (combine p cpows)).
-
-Definition Ceval_Rpoly (p : list R) (c : C) :=
- let cpows := map (fun j => Cpow c j) (seq 0 (length p)) in
- list_Cplus (map (fun '(a, b) => Cmult (RtoC a) b) (combine p cpows)).
-
-Fixpoint C_horner_eval (p : list C) (c : C) : C :=
- match p with
- | nil => R0
- | a :: p' => Cplus a (Cmult c (C_horner_eval p' c))
- end.
-
-Fixpoint C_horner_eval_Rpoly (p : list R) (c : C) : C :=
- match p with
- | nil => R0
- | a :: p' => Cplus a (Cmult c (C_horner_eval_Rpoly p' c))
- end.
-
-Lemma Ceval_horner_Rpoly (p : list R) (c : C) :
- Ceval_Rpoly p c = C_horner_eval_Rpoly p c.
-Proof.
- induction p.
- - now simpl.
- - unfold Ceval_Rpoly.
- simpl.
- rewrite Cmult_1_r.
- f_equal.
- rewrite <- IHp.
- unfold Ceval_Rpoly.
- now rewrite list_cplus_mult_seq_comm.
-Qed.
-
-Lemma pow2_S (j:nat) :
- exists (k : nat), 2^j = S k.
-Proof.
- exists (2^j-1).
- induction j.
- - now simpl.
- - simpl.
- rewrite IHj.
- lia.
-Qed.
-
-Lemma nth_root_half_pow_aux n :
- Cpow (nth_root (S n) (2 * (S n))) 2 = 1%R.
-Proof.
- replace (2 * (S n)) with (S (2 * n + 1)) by lia.
- rewrite Cpow_nth_root.
- do 2 replace (2 * (S n)) with (S (2 * n + 1)) by lia.
- now rewrite nth_root_Sn.
-Qed.
-
-Lemma pow2_inv x y : (x ^ 2)%R = y -> Rabs x = sqrt y.
-Proof.
- intros eqq1.
- apply (f_equal sqrt) in eqq1.
- destruct (Rle_dec 0 x).
- - intros.
- rewrite sqrt_pow2 in eqq1 by trivial.
- rewrite eqq1.
- rewrite Rabs_right; trivial.
- generalize (sqrt_pos y); lra.
- - assert (eqq1':sqrt ((-x) ^2) = sqrt y).
- {
- now rewrite <- Rsqr_pow2, <- Rsqr_neg, Rsqr_pow2.
- }
- rewrite sqrt_pow2 in eqq1' by lra.
- now rewrite Rabs_left by lra.
-Qed.
-
-Lemma Rabs_pm_l x y : Rabs x = y -> x = y \/ (- x)%R = y.
-Proof.
- unfold Rabs.
- destruct (Rcase_abs); [right|left]; lra.
-Qed.
-
-Lemma Rabs_pm_r x y : Rabs x = y -> x = y \/ x = (- y)%R.
-Proof.
- unfold Rabs.
- destruct (Rcase_abs); [right|left]; lra.
-Qed.
-
-Lemma Cpow_2 (c : C) :
- Cpow c 2 = 1%R -> c = 1%R \/ c = (-1)%R.
-Proof.
- unfold Cpow.
- rewrite Cmult_1_r.
- intros.
- destruct c.
- unfold Cmult, fst, snd, RtoC in H.
- injection H; intros; clear H.
- ring_simplify in H0.
- apply (f_equal (fun z => (/2 * z)%R)) in H0.
- do 2 rewrite <- Rmult_assoc in H0.
- rewrite <- Rinv_l_sym in H0; try lra.
- rewrite Rmult_1_l, Rmult_0_r in H0.
- apply Rmult_integral in H0.
- destruct H0; subst; ring_simplify in H1.
- - assert (0 <= r0 ^ 2)%R by apply pow2_ge_0.
- lra.
- - apply pow2_inv in H1.
- rewrite sqrt_1 in H1.
- apply Rabs_pm_r in H1.
- unfold RtoC.
- destruct H1; [left|right]; f_equal; lra.
-Qed.
-
-Lemma nth_root_half_pow n :
- nth_root (S n) (2 * (S n)) = (-1)%R.
-Proof.
- generalize (nth_root_half_pow_aux n); intros.
- apply Cpow_2 in H.
- destruct H; trivial.
- replace (2 * (S n)) with (S (2 * n + 1)) in H by lia.
- replace 1%R with R1 in H by lra.
- generalize (nth_root_not_1 (S n) (2*n+1)); intros.
- assert (S n mod S (2 * n + 1) <> 0).
- {
- rewrite Nat.mod_small; lia.
- }
- tauto.
-Qed.
-
-Lemma odd_roots_prim j n :
- Cpow (nth_root (2 * j + 1) (2 ^ (S n))) (2^n) = (-1)%R.
-Proof.
- generalize (pow2_S (S n)); intros.
- destruct H.
- rewrite H.
- rewrite Cpow_nth_root.
- rewrite <- H.
- assert ((2 ^ n * (2 * j + 1) mod (2 ^ S n)) =
- (2 ^ n mod (2 ^ S n))).
- {
- replace (2 ^n * (2 * j + 1)) with (2 ^ n + j*(2 * 2^n)) by lia.
- replace (2 ^ (S n)) with (2 * 2^n).
- - rewrite Nat.mod_add; try lia.
- assert (2^n <> 0).
- {
- apply Nat.pow_nonzero.
- lia.
- }
- lia.
- - simpl.
- lia.
- }
- rewrite H in H0.
- apply nth_root_mod in H0.
- rewrite <- H in H0.
- rewrite H0.
- generalize (pow2_S n); intros.
- destruct H1.
- simpl.
- replace (2 ^ n + (2 ^n + 0)) with (2 * 2^n) by lia.
- rewrite H1.
- now rewrite nth_root_half_pow.
-Qed.
-
-Definition odd_nth_roots (n : nat) :=
- (map (fun j => nth_root (2*j+1) (2 ^ (S n))) (seq 0 (2^n))).
-
-Definition V_odd_nth_roots (n : nat) : Vector C (2^n) :=
- fun j => nth_root (2 * (proj1_sig j) + 1) (2 ^ (S n)).
-
-Definition V_even_nth_roots (n : nat) : Vector C (2^n) :=
- fun j => nth_root (2 * (proj1_sig j)) (2 ^ (S n)).
-
-Definition V_nth_roots_half (n : nat) : Vector C (2^n) :=
- fun j => nth_root (proj1_sig j) (2 ^ (S n)).
-
-Definition odd_nth_roots_half (n : nat) :=
- (map (fun j => nth_root (2*j+1) (2 ^ (S (S n)))) (seq 0 (2^n))).
-
-Definition decode (p : list R) (n : nat) :=
- map (C_horner_eval_Rpoly p) (odd_nth_roots (S n)).
-
-Definition decode_eval (p : list R) (n : nat) :=
- map (Ceval_Rpoly p) (odd_nth_roots (S n)).
-
-Definition decode_Ceval (p : list C) (n : nat) :=
- map (Ceval_poly p) (odd_nth_roots (S n)).
-
-Definition decode_half (p : list R) (n : nat) :=
- map (C_horner_eval_Rpoly p) (odd_nth_roots_half n).
-
-Definition Cinner_prod (l1 l2 : list C) :=
- list_Cplus (map (fun '(a,b) => Cmult a b) (combine l1 l2)).
-
-Definition encode (cl : list C) (n : nat) :=
- let conjroots := map Cconj (odd_nth_roots (S n)) in
- map (fun c => Cdiv c (2^(S n))%R)
- (map (fun k => Cinner_prod cl (map (fun x => Cpow x k) conjroots))
- (seq 0 (2 ^(S n)))).
-
-Definition encode_half (cl : list C) (n : nat) :=
- let conjroots := map Cconj (odd_nth_roots_half n) in
- map (fun c => Rdiv c (2^(S n))%R)
- (map (fun k => (2 * (Re (Cinner_prod (firstn (2^n) cl) (map (fun x => Cpow x k) conjroots))))%R)
- (seq 0 (2 ^(S n)))).
-
-Lemma Im_div_real (c : C) (r : R) :
- r <> 0%R ->
- Im c = 0%R <-> Im (Cdiv c r) = 0%R.
-Proof.
- intros.
- destruct c.
- unfold Cdiv, Cinv, Cmult, Im, fst, snd.
- simpl.
- split; intros.
- - rewrite H0.
- now field.
- - field_simplify in H0; try easy.
- unfold pow in H0.
- field_simplify in H0; try easy.
- unfold Rdiv in H0.
- apply (f_equal (fun z => (z * r)%R)) in H0.
- rewrite Rmult_assoc, Rmult_0_l in H0.
- rewrite <- Rinv_l_sym in H0; trivial.
- lra.
-Qed.
-
-Lemma Cconj_im_0 (c : C) :
- Cconj c = c -> Im c = 0%R.
-Proof.
- destruct c.
- unfold Cconj; simpl.
- intros.
- injection H; intros.
- lra.
-Qed.
-
-Lemma map_inv_rev_even {A} n (l:list A) f (finv: forall x, f (f x) = x):
- length l = 2 * n ->
- map f l = rev l <->
- map f (firstn n l) = rev (skipn n l).
-Proof.
- intros llen.
- split; intros HH.
- - rewrite firstn_skipn_rev.
- rewrite llen, map_rev, <- skipn_map, map_rev.
- replace (2 * n - n) with n by lia.
- now rewrite HH, rev_involutive.
- - rewrite <- (firstn_skipn n l).
- rewrite map_app, rev_app_distr.
- f_equal; trivial.
- apply (f_equal (rev (A:=_))) in HH.
- repeat rewrite rev_involutive in HH.
- rewrite <- HH.
- rewrite map_rev, map_map.
- now erewrite map_ext; [rewrite map_id |].
-Qed.
-
-Lemma map_inv_rev_odd {A} n (l:list A) f (finv: forall x, f (f x) = x):
- length l = 2 * n + 1 ->
- map f l = rev l <->
- map f (firstn n l) = rev (skipn (S n) l) /\
- option_map f (nth_error l n) = (nth_error l n).
-Proof.
- intros llen.
- split; intros HH.
- - rewrite firstn_skipn_rev.
- rewrite llen, map_rev, <- skipn_map, map_rev.
- replace (2 * n + 1 - n) with (S n) by lia.
- split.
- + now rewrite HH, rev_involutive.
- + apply (f_equal (map f )) in HH.
- rewrite map_map in HH.
- erewrite map_ext in HH; [rewrite map_id in HH|]; trivial.
- apply (f_equal (fun x => nth_error x n)) in HH.
- rewrite map_rev, rev_nth_error in HH by (rewrite map_length, llen; lia).
- rewrite map_length in HH.
- replace (length l - S n)%nat with n in HH by lia.
- now rewrite nth_error_map in HH.
- - destruct HH as [HH1 HH2].
- rewrite <- (firstn_skipn (S n) l) at 2.
- rewrite <- (firstn_skipn n l) at 1.
- rewrite map_app, rev_app_distr.
- f_equal; trivial.
- case_eq (nth_error l n).
- + intros a ntha.
- rewrite (map_skipn_S_error _ _ _ ntha).
- apply (f_equal (@rev _)) in HH1.
- rewrite rev_involutive in HH1.
- rewrite <- HH1.
- rewrite map_cons, map_rev, map_map.
- erewrite map_ext; [rewrite map_id | auto].
- apply nth_error_split in ntha.
- destruct ntha as [l1 [l2 [? lenl1]]]; subst.
- repeat rewrite firstn_app.
- repeat rewrite rev_app_distr.
- replace ((length l1) - length l1) with 0 by lia.
- replace (S (length l1) - length l1) with 1 by lia.
- rewrite firstn_cons.
- repeat rewrite firstn_O.
- repeat rewrite firstn_all2 by lia.
- simpl; f_equal.
- rewrite nth_error_app2 in HH2 by trivial.
- replace (length l1 - length l1) with 0 in HH2 by lia.
- simpl in HH2.
- congruence.
- + intros HH.
- apply nth_error_None in HH.
- assert (n = 0) by lia.
- subst.
- destruct l; simpl in *; [lia |].
- destruct l; simpl in *; [| lia].
- congruence.
-Qed.
-
-Lemma conj_rev_even n cl :
- length cl = 2 * n ->
- map Cconj cl = rev cl <->
- map Cconj (firstn n cl) = firstn n (rev cl).
-Proof.
- intros llen.
- generalize (map_inv_rev_even n cl Cconj (Cconj_conj) llen).
- rewrite firstn_rev.
- now replace (length cl - n) with n by lia.
-Qed.
-
-Lemma conj_rev_odd cl n :
- length cl = 2 * n + 1 ->
- map Cconj cl = rev cl <->
- (map Cconj (firstn n cl) = rev (skipn (S n) cl) /\
- Im (nth n cl Ci) = 0%R).
-Proof.
- intros llen.
- rewrite (map_inv_rev_odd n cl Cconj (Cconj_conj) llen).
- split; intros [? HH]; split; trivial.
- - rewrite (nth_error_nth' cl Ci) in HH by lia.
- simpl in HH.
- invcs HH.
- apply Cconj_im_0.
- now rewrite Cconj_conj.
- - case_eq (nth_error cl n); simpl; trivial; intros.
- rewrite (nth_error_nth _ _ _ H0) in HH.
- destruct c; unfold Cconj; simpl in *.
- rewrite HH.
- do 2 f_equal.
- lra.
-Qed.
-
-Lemma conj_rev_half (cl_half:list C) :
- let cl := cl_half ++ rev (map Cconj cl_half) in
- map Cconj cl = rev cl.
-Proof.
- intros.
- pose (n := length cl_half).
- assert (length cl = 2*n).
- {
- unfold cl.
- rewrite app_length.
- rewrite rev_length.
- rewrite map_length.
- lia.
- }
- generalize (conj_rev_even n cl H); intros.
- apply H0.
- unfold cl.
- rewrite firstn_app.
- replace (length cl_half) with n by easy.
- replace (n - n) with 0 by lia.
- simpl.
- rewrite rev_app_distr.
- rewrite rev_involutive.
- rewrite firstn_app.
- replace (n - length (map Cconj cl_half)) with 0.
- - rewrite map_app.
- simpl.
- f_equal.
- now rewrite firstn_map.
- - rewrite map_length.
- lia.
- Qed.
-
-Lemma conj_rev_half_conv n (cl:list C) :
- length cl = 2*n ->
- map Cconj cl = rev cl ->
- let cl_half := firstn n cl in
- cl = cl_half ++ rev (map Cconj cl_half) .
-Proof.
- intros.
- generalize (conj_rev_even n cl H); intros.
- apply H1 in H0.
- unfold cl_half.
- rewrite H0.
- rewrite firstn_rev.
- rewrite rev_involutive.
- replace (length cl - n) with n by lia.
- now rewrite firstn_skipn.
-Qed.
-
-Lemma list_Cplus_rev (cl : list C) :
- list_Cplus (rev cl) = list_Cplus cl.
-Proof.
- apply list_Cplus_perm_proper.
- apply Permutation_sym.
- apply Permutation_rev.
-Qed.
-
-Lemma list_Cplus_conj (cl : list C) :
- list_Cplus (map Cconj cl) = Cconj (list_Cplus cl).
-Proof.
- induction cl.
- - unfold Cconj, fst, snd; simpl.
- unfold RtoC; f_equal; lra.
- - simpl.
- rewrite Cplus_conj.
- now f_equal.
-Qed.
-
-Lemma conj_rev_half_sum n (cl : list C) :
- length cl = 2*n ->
- map Cconj cl = rev cl ->
- let sum_cl_half := list_Cplus (firstn n cl) in
- list_Cplus cl = (sum_cl_half + Cconj sum_cl_half)%C.
-Proof.
- intros.
- rewrite (conj_rev_half_conv n cl); trivial.
- rewrite list_Cplus_app.
- unfold sum_cl_half.
- f_equal.
- rewrite list_Cplus_rev.
- now rewrite list_Cplus_conj.
-Qed.
-
-Lemma Cplus_conj (c : C) :
- Cplus c (Cconj c) = (2 * Re c)%R.
-Proof.
- destruct c.
- unfold Cplus, Cconj, RtoC, Re, fst, snd.
- f_equal; lra.
-Qed.
-
-Lemma conj_rev_half_sum_alt n (cl : list C) :
- length cl = 2*n ->
- map Cconj cl = rev cl ->
- let sum_cl_half := list_Cplus (firstn n cl) in
- list_Cplus cl = (2*(Re sum_cl_half))%R.
-Proof.
- intros.
- rewrite (conj_rev_half_sum n); trivial.
- now rewrite Cplus_conj.
-Qed.
-
-Lemma conj_rev_rev (cl : list C) :
- map Cconj cl = rev cl ->
- cl = map Cconj (rev cl).
-Proof.
- intros.
- apply (f_equal (fun l => rev l)) in H.
- rewrite rev_involutive in H.
- rewrite <- H at 1.
- now rewrite map_rev.
-Qed.
-
-Lemma list_Cplus_conj_rev_0 (cl : list C):
- length cl < 2 ->
- map Cconj cl = rev cl ->
- Im (list_Cplus cl) = 0%R.
-Proof.
- intros.
- pose (n := length cl).
- destruct cl.
- - now simpl.
- - destruct cl.
- + simpl.
- simpl in H0.
- inversion H0.
- rewrite H2.
- apply Cconj_im_0 in H2.
- now rewrite Rplus_0_r.
- + simpl in H.
- lia.
-Qed.
-
-Lemma list_Cplus_conj_rev_recur (n : nat) :
- (forall (cl : list C),
- length cl = n ->
- map Cconj cl = rev cl ->
- Im (list_Cplus cl) = 0%R) ->
- forall (cl : list C),
- length cl = n + 2 ->
- map Cconj cl = rev cl ->
- Im (list_Cplus cl) = 0%R.
-Proof.
- intros.
- destruct cl; trivial.
- simpl in H0.
- assert (lcl: length cl = n+1) by lia.
- assert(exists (c2 : C) (cl2 : list C),
- cl = cl2 ++ (c2 :: nil)).
- {
- assert (cl <> nil).
- {
- unfold not; intros.
- rewrite H2 in H0.
- simpl in H0.
- lia.
- }
- destruct (exists_last H2) as [? [??]].
- exists x0.
- now exists x.
- }
- destruct H2 as [c2 [cl2 ?]].
- rewrite H2.
- specialize (H cl2).
- simpl.
- rewrite list_Cplus_app.
- rewrite im_plus.
- rewrite H2 in H1.
- simpl in H1.
- rewrite map_app in H1.
- rewrite rev_app_distr in H1.
- simpl in H1.
- inversion H1.
- rewrite Cconj_conj in H5.
- apply app_inv_tail in H5.
- rewrite H; trivial.
- - simpl.
- lra.
- - rewrite H2 in H0.
- rewrite app_length in H0.
- simpl in H0.
- lia.
- Qed.
-
-Lemma pair_induction (P : nat -> Prop) :
- P 0 -> P 1 ->
- (forall n, P n -> P (S n) -> P (S (S n))) ->
- forall n, P n.
-Proof.
- intros H0 H1 Hstep n.
- enough (P n /\ P (S n)) by easy.
- induction n; intuition.
-Qed.
-
-Lemma list_Cplus_conj_rev (cl : list C):
- map Cconj cl = rev cl ->
- Im (list_Cplus cl) = 0%R.
-Proof.
- intros HH.
- apply (list_cons_app_hyp
- (fun cl => Im (list_Cplus cl) = 0%R)
- (fun x y => Cconj x = y)).
- - trivial.
- - simpl; intros.
- rewrite Cconj_im_0; trivial.
- lra.
- - intros.
- simpl.
- rewrite <- Permutation_cons_append; simpl.
- unfold Im in *.
- rewrite H1.
- rewrite <- H.
- simpl; lra.
- - rewrite <- HH.
- apply Forall2_map_Forall.
- rewrite Forall_forall; trivial.
-Qed.
-
-Lemma combine_app {T} (cl1 cl2 cl1' cl2' : list T) :
- length cl1 = length cl2 ->
- length cl1' = length cl2' ->
- combine (cl1 ++ cl1') (cl2 ++ cl2') = combine cl1 cl2 ++ combine cl1' cl2'.
-Proof.
- revert cl2.
- induction cl1; intros; simpl; trivial.
- - simpl in H.
- symmetry in H.
- apply length_zero_iff_nil in H.
- rewrite H; now simpl.
- - destruct cl2; simpl; [now simpl in H|].
- rewrite IHcl1; trivial.
- simpl in H.
- lia.
- Qed.
-
-Lemma combine_rev {T} (cl1 cl2 : list T) :
- length cl1 = length cl2 ->
- combine (rev cl1) (rev cl2) = rev (combine cl1 cl2).
-Proof.
- revert cl2.
- induction cl1; intros ; simpl; trivial.
- destruct cl2; simpl.
- - now rewrite combine_nil.
- - simpl in H.
- injection H; intros.
- rewrite <- IHcl1; trivial.
- rewrite combine_app.
- + now simpl.
- + now do 2 rewrite rev_length.
- + now simpl.
- Qed.
-
-Lemma Cmult_combine_rev (cl1 cl2 : list C) :
- length cl1 = length cl2 ->
- map (fun '(a, b) => (a * b)%C) (combine (rev cl1) (rev cl2)) =
- rev (map (fun '(a, b) => (a * b)%C) (combine cl1 cl2)).
-Proof.
- intros.
- rewrite <- map_rev.
- f_equal.
- now apply combine_rev.
-Qed.
-
-Lemma Cmult_combine_conv (cl1 cl2 : list C) :
- map (fun '(a, b) => (a * b)%C) (combine (map Cconj cl1) (map Cconj cl2)) =
- map Cconj (map (fun '(a, b) => (a * b)%C) (combine cl1 cl2)).
-Proof.
- rewrite combine_map.
- do 2 rewrite map_map.
- apply map_ext.
- intros.
- destruct a.
- now rewrite Cmult_conj.
-Qed.
-
-Lemma map_mult_conj_rev (cl1 cl2 : list C):
- map Cconj cl1 = rev cl1 ->
- map Cconj cl2 = rev cl2 ->
- length cl1 = length cl2 ->
- let cl := map (fun '(a,b) => Cmult a b) (combine cl1 cl2) in
- map Cconj cl = rev cl.
-Proof.
- intros.
- assert (combine (map Cconj cl1) (map Cconj cl2) =
- combine (rev cl1) (rev cl2)).
- {
- now rewrite H, H0.
- }
- apply (f_equal (fun ll => map (fun '(a, b) => (a * b)%C) ll)) in H2.
- now rewrite Cmult_combine_rev, Cmult_combine_conv in H2.
-Qed.
-
-Lemma map_pow_conj_rev (cl : list C) (n : nat) :
- map Cconj cl = rev cl ->
- map Cconj (map (fun c => Cpow c n) cl) =
- rev (map (fun c => Cpow c n) cl).
-Proof.
- intros.
- apply (f_equal (fun ll => map (fun cc => Cpow cc n) ll)) in H.
- rewrite map_map in H.
- rewrite map_rev in H.
- rewrite <- H.
- rewrite map_map.
- apply map_ext.
- intros.
- now rewrite Cpow_conj.
-Qed.
-
-Lemma map_conj_conj_rev (cl : list C) :
- map Cconj cl = rev cl ->
- map Cconj (map Cconj cl) =
- rev (map Cconj cl).
-Proof.
- intros.
- apply (f_equal (fun ll => map Cconj ll)) in H.
- rewrite map_map in H.
- rewrite map_rev in H.
- rewrite <- H.
- rewrite map_map.
- now apply map_ext.
-Qed.
-
-Lemma odd_nth_roots_conj_rev n :
- let cl := map Cconj (odd_nth_roots (S n)) in
- map Cconj cl = rev cl.
-Proof.
- simpl.
- apply map_conj_conj_rev.
- unfold odd_nth_roots.
- rewrite <- map_rev.
- rewrite rev_seq.
- do 2 rewrite map_map.
- apply map_ext_in; intros.
- rewrite plus_0_r.
- apply in_seq in H.
- destruct H as [_ alt].
- rewrite plus_0_l in alt.
- destruct (pow2_S (S (S n))); intros.
- rewrite H.
- rewrite nth_root_conj.
- rewrite nth_root_inv.
- f_equal.
- rewrite <- H.
- replace (2 ^ (S (S n))) with (2 * 2 ^ (S n)) by (simpl; lia).
- rewrite Nat.mod_small; lia.
-Qed.
-
-Lemma Cinner_prod_conj_rev (cl1 cl2 : list C) :
- length (cl1) = length cl2 ->
- map Cconj cl1 = rev cl1 ->
- map Cconj cl2 = rev cl2 ->
- Im (Cinner_prod cl1 cl2) = 0%R.
-Proof.
- intros.
- unfold Cinner_prod.
- apply list_Cplus_conj_rev.
- apply map_mult_conj_rev; trivial.
-Qed.
-
-Lemma encode_real (cl : list C) (n : nat):
- map Cconj cl = rev cl ->
- length cl = length (odd_nth_roots (S n)) ->
- forall (x : C),
- In x (encode cl n) -> Im x = 0%R.
-Proof.
- intros.
- unfold encode in H1.
- apply in_map_iff in H1.
- destruct H1 as [? [? ?]].
- apply in_map_iff in H2.
- destruct H2 as [? [? ?]].
- assert (Im x0 = 0%R).
- {
- rewrite <- H2.
- apply Cinner_prod_conj_rev; trivial.
- - now do 2 rewrite map_length.
- - apply map_pow_conj_rev.
- apply odd_nth_roots_conj_rev.
- }
- rewrite <- H1.
- apply Im_div_real; trivial.
- apply pow_nonzero.
- lra.
-Qed.
-
-Lemma Re_Im (c : C) :
- Im c = 0%R <-> c = RtoC (Re c).
-Proof.
- destruct c.
- unfold RtoC, Im, Re, fst, snd.
- split; intros.
- - now rewrite H.
- - now inversion H.
-Qed.
-
-Lemma clist_real (cl : list C) :
- (forall (x : C),
- In x cl -> Im x = 0%R) ->
- cl = map RtoC (map Re cl).
-Proof.
- intros.
- rewrite <- List.map_id at 1.
- rewrite map_map.
- apply map_ext_in.
- intros.
- specialize (H a H0).
- now apply Re_Im.
-Qed.
-
-Lemma encode_real_alt (cl : list C) (n : nat):
- map Cconj cl = rev cl ->
- length cl = length (odd_nth_roots (S n)) ->
- encode cl n = map RtoC (map Re (encode cl n)).
-Proof.
- intros.
- apply clist_real.
- now apply encode_real.
-Qed.
-
-Lemma Re_Cmult_R_C (r : R) (c : C) :
- Re (Cmult r c) = Rmult r (Re c).
-Proof.
- destruct c; simpl.
- ring.
-Qed.
-
-Lemma Re_Cmult_list_Cplus (r : R) (cl : list C) :
- RtoC (Re (Cmult r (list_Cplus cl))) =
- Cmult r (list_Cplus (map RtoC (map Re cl))).
-Proof.
- intros.
- induction cl.
- - simpl.
- unfold Cmult, Re, RtoC, fst, snd.
- f_equal; lra.
- - simpl.
- rewrite Rmult_0_l.
- rewrite Rminus_0_r.
- rewrite Cmult_plus_distr_l.
- rewrite <- IHcl.
- rewrite <- RtoC_mult.
- rewrite <- RtoC_plus.
- f_equal.
- rewrite Rmult_plus_distr_l.
- f_equal.
- now rewrite Re_Cmult_R_C.
-Qed.
-
-Lemma map_commute {A} (l : list A) (f g : A -> A) :
- (forall a, f (g a) = g (f a)) ->
- map f (map g l) = map g (map f l).
-Proof.
- intros.
- do 2 rewrite map_map.
- apply map_ext.
- apply H.
-Qed.
-
-Lemma double_Re_commute (c : C) :
- (2 * (Re c))%R = Re (2%R * c)%C.
-Proof.
- destruct c.
- unfold RtoC, Re, fst, snd.
- simpl.
- now ring.
-Qed.
-
-Lemma div_Re_commute (c : C) (r : R) :
- (r <> 0)%R ->
- (Re c / r)%R = Re (c / r)%C.
-Proof.
- intros.
- destruct c.
- unfold RtoC, Re, fst, snd.
- simpl.
- now field.
-Qed.
-
-Lemma encode_half_correct (cl : list C) (n : nat):
- length cl = 2^S n ->
- map Cconj cl = rev cl ->
- encode_half cl n = map Re (encode cl n).
-Proof.
- intros.
- unfold encode_half, encode.
- rewrite map_map.
- rewrite map_map.
- rewrite map_map.
- apply map_ext.
- intros.
- rewrite double_Re_commute.
- rewrite <- div_Re_commute; [| apply pow_nonzero; lra].
- f_equal.
- unfold Cinner_prod.
- symmetry.
- rewrite (conj_rev_half_sum_alt (2 ^ n)).
- - rewrite double_Re_commute.
- rewrite re_RtoC.
- do 3 f_equal.
- rewrite firstn_map.
- f_equal.
- rewrite combine_firstn.
- f_equal.
- rewrite firstn_map.
- f_equal.
- unfold odd_nth_roots, odd_nth_roots_half.
- do 2 rewrite firstn_map.
- do 2 f_equal.
- apply firstn_seq.
- simpl; lia.
- - rewrite map_length.
- rewrite combine_length.
- do 2 rewrite map_length.
- unfold odd_nth_roots.
- rewrite map_length.
- rewrite seq_length.
- rewrite H.
- apply Nat.min_id.
- - apply map_mult_conj_rev; trivial.
- + apply map_pow_conj_rev.
- apply odd_nth_roots_conj_rev.
- + rewrite map_length.
- rewrite map_length.
- unfold odd_nth_roots.
- rewrite map_length.
- now rewrite seq_length.
-Qed.
-
-Lemma encode_half_correct_alt (cl : list C) (n : nat):
- length cl = 2^S n ->
- map Cconj cl = rev cl ->
- map RtoC (encode_half cl n) = encode cl n.
-Proof.
- intros.
- generalize (encode_half_correct cl n H H0); intros.
- rewrite (encode_real_alt cl n H0).
- - now f_equal.
- - unfold odd_nth_roots.
- now rewrite map_length, seq_length.
-Qed.
-
-Definition peval_mat (l : list C) :=
- let n := length l in
- map (fun c => map (fun k => Cpow c k) (seq 0 n)) l.
-
-Definition V_peval_mat {n} (roots : Vector C n) : Matrix C n n :=
- (fun n1 n2 => Cpow (roots n1) (proj1_sig n2)).
-
-Definition conj_mat (m : list (list C)) :=
- map (fun cl => map Cconj cl) m.
-
-Definition V_conj_mat {n1 n2} (m : Matrix C n1 n2) :=
- fun n1' n2' => Cconj (m n1' n2').
-
-Definition V_inv_mat {n1 n2} (m : Matrix C n1 n2) :=
- fun n1' n2' => Cinv (m n1' n2').
-
-Definition transpose_mat (m : list (list C)) :=
- let n := length m in
- map (fun k =>
- map (fun cl => nth k cl (RtoC 0%R)) m)
- (seq 0 n).
-
-Definition mat_vec_mult (m : list (list C)) (v : list C) :=
- map (fun ml => Cinner_prod ml v) m.
-
-Definition vector_sum {n} (v : Vector C n) :=
- vector_fold_right Cplus 0%R v.
-
-Definition V_inner_prod {n} (v1 v2 : Vector C n) :=
- vector_sum (fun n' => Cmult (v1 n') (v2 n')).
-
-Definition V_mat_vec_mult {n1 n2} (m : Matrix C n1 n2) (v : Vector C n2) :=
- fun n' => V_inner_prod (m n') v.
-
-Definition vec_mat_mult (v : list C) (m : list (list C)) :=
- map (fun cl => Cinner_prod v cl) (transpose_mat m).
-
-Definition V_vec_mat_mult {n1 n2} (v : Vector C n1) (m : Matrix C n1 n2) :=
- fun n' => V_inner_prod v ((transpose m) n').
-
-Definition mat_mat_mult (m1 m2 : list (list C)) :=
- map (fun v => mat_vec_mult m1 v) (transpose_mat m2).
-
-Definition V_mat_mat_mult {n1 n2 n3} (m1: Matrix C n1 n2) (m2 : Matrix C n2 n3) :=
- (fun n1' n3' => V_inner_prod (m1 n1') ((transpose m2) n3')).
-
-(*
-Lemma mmv_mult_assoc_row (m2: list (list C)) (r1 v : list C) :
- length r1 = length m2 ->
- (forall r2, In r2 m2 -> length r2 = length v) ->
- Cinner_prod r1 (mat_vec_mult m2 v) =
- Cinner_prod (vec_mat_mult r1 m2) v.
-Proof.
- intros.
- unfold mat_vec_mult, vec_mat_mult.
- unfold Cinner_prod.
- f_equal.
-Admitted.
-*)
-
-Definition Vscale {n} (r : C) (v : Vector C n) :=
- fun n' => Cmult r (v n').
-
-Definition Vscale_r {n} (r : C) (v : Vector C n) :=
- fun n' => Cmult (v n') r.
-
-Lemma vector_fold_right_Cplus_0 ( v : Vector C 0) :
- vector_fold_right Cplus 0%R v = 0%R.
-Proof.
- unfold vector_fold_right, vector_fold_right_dep, vector_fold_right_bounded_dep.
- now simpl.
-Qed.
-
-Lemma vector_sum_scale {n} (c : C) (v : Vector C n) :
- Cmult c (vector_sum v) = vector_sum (Vscale c v).
-Proof.
- unfold vector_sum, Vscale.
- induction n.
- - do 2 rewrite vector_fold_right_Cplus_0.
- now rewrite Cmult_0_r.
- - rewrite vector_fold_right_Sn.
- rewrite Cmult_plus_distr_l.
- rewrite IHn.
- now rewrite vector_fold_right_Sn.
-Qed.
-
-Lemma vector_sum_scale_r {n} (v : Vector C n) (c : C) :
- Cmult (vector_sum v) c = vector_sum (Vscale_r c v).
-Proof.
- unfold vector_sum, Vscale.
- induction n.
- - unfold vector_fold_right, vector_fold_right_dep, vector_fold_right_bounded_dep.
- simpl.
- now rewrite Cmult_0_l.
- - rewrite vector_fold_right_Sn.
- rewrite Cmult_plus_distr_r.
- rewrite IHn.
- now rewrite vector_fold_right_Sn.
-Qed.
-
-Definition vmap' {A B} {n} (f:A->B) (v:Vector A n) : Vector B n
- := fun n' => f (v n').
-
-Definition mmap' {A B} {m n} (f:A->B) (mat:Matrix A m n) : Matrix B m n
- := fun n1' n2' => f (mat n1' n2').
-
-Lemma vector_sum_const {n} (c : C) :
- vector_sum (ConstVector n c) = Cmult (RtoC (INR n)) c.
-Proof.
- unfold vector_sum, ConstVector.
- induction n.
- - rewrite vector_fold_right_Cplus_0.
- now rewrite Cmult_0_l.
- - rewrite vector_fold_right_Sn.
- unfold vlast, vdrop_last.
- rewrite S_INR.
- replace ((INR n + 1)%R * c)%C with
- (c + (INR n) * c)%C.
- + f_equal.
- rewrite <- IHn.
- f_equal.
- apply vec_eq_eq; intros x.
- now destruct x.
- + unfold RtoC.
- destruct c.
- unfold Cplus, Cmult, fst, snd.
- f_equal; lra.
- Qed.
-
-Lemma vector_sum_sum_transpose {n1 n2} (m : Matrix C n1 n2) :
- vector_sum (fun n' => vector_sum (m n')) =
- vector_sum (fun n'' => vector_sum ((transpose m) n'')).
-Proof.
- unfold transpose.
- unfold vector_sum.
- induction n1.
- - induction n2.
- + now do 2 rewrite vector_fold_right_Cplus_0.
- + rewrite vector_fold_right_Cplus_0.
- generalize (@vector_sum_const (S n2) (RtoC 0%R)); intros.
- rewrite Cmult_0_r in H.
- rewrite <- H at 1.
- unfold vector_sum.
- apply vector_fold_right_ext.
- intros ?.
- destruct i.
- rewrite vector_fold_right_Cplus_0.
- now unfold ConstVector.
- - induction n2.
- + rewrite vector_fold_right_Cplus_0.
- generalize (@vector_sum_const (S n1) (RtoC 0%R)); intros.
- rewrite Cmult_0_r in H.
- rewrite <- H at 2.
- unfold vector_sum.
- apply vector_fold_right_ext.
- intros ?.
- rewrite vector_fold_right_Cplus_0.
- now unfold ConstVector.
- + do 2 rewrite vector_fold_right_Sn.
- unfold vlast.
-
- admit.
-Admitted.
-
-
-Lemma V_mmv_mult_assoc {n1 n2 n3}
- (m1 : Matrix C n1 n2)
- (m2: Matrix C n2 n3)
- (v : Vector C n3) :
- V_mat_vec_mult m1 (V_mat_vec_mult m2 v) =
- V_mat_vec_mult (V_mat_mat_mult m1 m2) v.
-Proof.
- intros.
- unfold V_mat_vec_mult, V_mat_mat_mult.
- unfold V_inner_prod, transpose.
- apply vec_eq_eq; intros ?.
- Admitted.
-
-
-(*
-Lemma mmv_mult_assoc (m1 m2: list (list C)) (v : list C) :
- (forall r1, In r1 m1 -> length r1 = length m2) ->
- (forall r2, In r2 m2 -> length r2 = length v) ->
- mat_vec_mult m1 (mat_vec_mult m2 v) =
- mat_vec_mult (mat_mat_mult m1 m2) v.
-Proof.
- intros.
- unfold mat_mat_mult.
- unfold mat_vec_mult at 1.
- replace (map (fun ml => Cinner_prod ml (mat_vec_mult m2 v)) m1) with
- (map (fun r1 =>
- Cinner_prod (vec_mat_mult r1 m2) v) m1).
- - unfold vec_mat_mult, mat_vec_mult.
- unfold Cinner_prod.
- rewrite map_map.
- admit.
- - apply map_ext_in.
- intros.
- rewrite mmv_mult_assoc_row; trivial.
- now specialize (H a H1).
-
-Admitted.
-*)
-
-Lemma map_Cmult_combine_comm l1 l2 :
- map (fun '(a0, b) => (a0 * b)%C) (combine l1 l2) =
- map (fun '(a0, b) => (a0 * b)%C) (combine l2 l1).
-Proof.
- rewrite combine_swap.
- rewrite map_map.
- apply map_ext.
- intros.
- destruct a.
- unfold fst, snd.
- now rewrite Cmult_comm.
-Qed.
-
-Lemma peval_mat_decode_Ceval (p : list C) n :
- length p = length (odd_nth_roots (S n)) ->
- decode_Ceval p n =
- mat_vec_mult (peval_mat (odd_nth_roots (S n))) p.
-Proof.
- intros.
- unfold decode_Ceval, mat_vec_mult, peval_mat, Ceval_poly.
- rewrite map_map.
- apply map_ext.
- intros.
- unfold Cinner_prod.
- apply list_Cplus_perm_proper, refl_refl.
- rewrite map_Cmult_combine_comm.
- now rewrite <- H.
-Qed.
-
-Lemma map_Cinv_l (lc : list C) (c : C) :
- c <> 0%R ->
- map (fun x => (x / c * c)%C) lc = lc.
-Proof.
- intros.
- replace lc with (map (fun (x : C) => x) lc) at 2.
- - apply map_ext.
- intros.
- unfold Cdiv.
- rewrite <- Cmult_assoc.
- rewrite Cinv_l; trivial.
- now rewrite Cmult_1_r.
- - now rewrite map_id.
-Qed.
-
-Lemma nth_map_seq {A} (f : nat -> A) (d : A) (a m : nat) :
- In a (seq 0 m) ->
- nth a (map f (seq 0 m)) d = f a.
-Proof.
- intros.
- induction m.
- - now simpl in H.
- - rewrite seq_S.
- rewrite map_app.
- destruct (lt_dec a m).
- + simpl.
- rewrite app_nth1.
- * rewrite IHm; trivial.
- rewrite in_seq; lia.
- * now rewrite map_length, seq_length.
- + rewrite in_seq in H.
- assert (a = m) by lia.
- rewrite H0.
- simpl.
- rewrite app_nth2.
- * rewrite map_length, seq_length.
- now replace (m - m) with 0 by lia.
- * rewrite map_length, seq_length; lia.
-Qed.
-
-Lemma conj_trans_mat_encode (cl : list C) n :
- length cl = length (odd_nth_roots (S n)) ->
- map (fun c => Cmult c (2^(S n))%R) (encode cl n) =
- mat_vec_mult (conj_mat (transpose_mat (peval_mat (odd_nth_roots (S n)))))
- cl.
- Proof.
- intros.
- unfold encode.
- rewrite map_map.
- assert (length (odd_nth_roots (S n)) = 2^(S n)).
- {
- unfold odd_nth_roots.
- now rewrite map_length, seq_length.
- }
- rewrite map_Cinv_l.
- - unfold mat_vec_mult, conj_mat.
- rewrite map_map.
- unfold transpose_mat.
- rewrite map_map.
- unfold peval_mat.
- rewrite map_length.
- rewrite H0.
- apply map_ext_in.
- intros.
- unfold Cinner_prod.
- apply list_Cplus_perm_proper, refl_refl.
- rewrite map_Cmult_combine_comm.
- do 2 f_equal.
- do 3 rewrite map_map.
- apply map_ext.
- intros.
- rewrite <- Cpow_conj.
- f_equal.
- now rewrite nth_map_seq.
- - unfold RtoC.
- unfold not; intros.
- replace (S n) with (n + 1) in H1 by lia.
- injection H1; intros.
- generalize (pow_nonzero 2 (n+1)); intros.
- cut_to H3; lra.
- Qed.
-
-(*
-Lemma encode_decode_eval (cl : list C) (n : nat):
- map Cconj cl = rev cl ->
- length cl = length (odd_nth_roots (S n)) ->
- decode_eval (map Re (encode cl n)) n = cl.
-Proof.
- intros.
- unfold encode, decode_eval.
- rewrite map_map.
- rewrite map_map.
- unfold Ceval_Rpoly.
- rewrite map_length.
- rewrite seq_length.
- etransitivity.
- - apply map_ext; intros.
- apply list_Cplus_perm_proper.
- rewrite combine_map.
- rewrite combine_self.
- repeat rewrite map_map.
- reflexivity.
- - apply nth_error_eqs; intros.
- destruct (lt_dec i (length cl)).
- + rewrite nth_error_map.
- unfold odd_nth_roots at 2.
- rewrite nth_error_map.
- rewrite seq_nth_error.
- * unfold option_map.
- admit.
- * unfold odd_nth_roots in H0.
- rewrite map_length, seq_length in H0.
- congruence.
- + assert (length cl <= i) by lia.
- apply nth_error_None in H1.
- rewrite H1.
- apply nth_error_None.
- rewrite map_length, <- H0; lia.
-Admitted.
-
-Lemma encode_decode (cl : list C) (n : nat):
- map Cconj cl = rev cl ->
- length cl = length (odd_nth_roots (S n)) ->
- decode (map Re (encode cl n)) n = cl.
-Proof.
- intros.
- rewrite <- encode_decode_eval with (n := n); trivial.
- unfold decode, decode_eval.
- apply map_ext.
- intros.
- now rewrite Ceval_horner_Rpoly.
-Qed.
- *)
-
-
-(* claim (nxn vandermonde on odd roots) x conjugate transpose = n * I. *)
-
-Lemma mult_conj_root j n :
- Cmult (nth_root j (S n)) (Cconj (nth_root j (S n))) = 1%R.
-Proof.
- rewrite nth_root_conj.
- rewrite Cinv_r; trivial.
- apply nth_root_not_0.
-Qed.
-
-Lemma nth_root_half n :
- nth_root (2 ^n) (2 ^ (S n)) = (-1)%R.
-Proof.
- destruct (pow2_S (S n)).
- generalize (odd_roots_prim 0 n); intros.
- replace (2 * 0 +1) with 1 in H by lia.
- rewrite H in H0.
- rewrite Cpow_nth_root in H0.
- rewrite <- H in H0.
- now replace (2^n * (2 * 0 + 1)) with (2 ^ n) in H0 by lia.
-Qed.
-
-Lemma nth_root_opp j n :
- (nth_root j (2 ^ (S n)) + nth_root (j + 2^n) (2 ^ (S n)) = 0%R)%C.
-Proof.
- destruct (pow2_S (S n)).
- rewrite H.
- rewrite <- nth_root_mul.
- rewrite <- H.
- rewrite nth_root_half.
- ring.
-Qed.
-
-Hint Rewrite Nat2Z.inj_mul : of_nat_re.
-Hint Rewrite Nat2Z.inj_add : of_nat_re.
-Hint Rewrite Nat2Z.inj_sub : of_nat_re.
-Hint Rewrite Nat2Z.inj_mod : of_nat_re.
-Hint Rewrite Nat2Z.inj_pow : of_nat_re.
-Hint Rewrite Nat2Z.inj_div : of_nat_re.
-Hint Rewrite Nat2Z.inj_0 : of_nat_re.
-
-Lemma of_nat_1 : Z.of_nat 1 = 1%Z.
-Proof.
- lia.
-Qed.
-
-Lemma of_nat_2 : Z.of_nat 2 = 2%Z.
-Proof.
- lia.
-Qed.
-
-Hint Rewrite of_nat_1 : of_nat_re.
-Hint Rewrite of_nat_2 : of_nat_re.
-
-
-Lemma mod_odd_even x y :
- y <> 0 ->
- Nat.Odd x -> Nat.Even y -> Nat.Odd (x mod y).
-Proof.
- intros.
- rewrite Nat.mod_eq; trivial.
- generalize (Nat.Even_mul_l y (x / y) H1); intros HH2.
- apply Nat.odd_spec.
- rewrite Nat.odd_sub.
- - apply Nat.odd_spec in H0.
- rewrite H0.
- apply Nat.even_spec in HH2.
- now rewrite <- Nat.negb_even, HH2.
- - now apply Nat.mul_div_le.
-Qed.
-
-Lemma mod_even_even x y :
- y <> 0 ->
- Nat.Even x -> Nat.Even y -> Nat.Even (x mod y).
-Proof.
- intros.
- rewrite Nat.mod_eq; trivial.
- generalize (Nat.Even_mul_l y (x / y) H1); intros HH2.
- apply Nat.even_spec.
- rewrite Nat.even_sub.
- - apply Nat.even_spec in H0.
- rewrite H0.
- apply Nat.even_spec in HH2.
- now rewrite HH2.
- - now apply Nat.mul_div_le.
-Qed.
-
-Lemma odd_nth_root_div_pow_sum_0 j k n :
- (2*j+1) mod (2^(S n)) <> (2*k+1) mod (2 ^ (S n)) ->
- let w := Cdiv (nth_root (2*j+1) (2 ^ (S n))) (nth_root (2*k+1) (2 ^ (S n))) in
- list_Cplus (map (fun j => Cpow w j) (seq 0 (2^n))) = 0%R.
-Proof.
- intros.
- destruct (pow2_S n).
- rewrite H0.
- destruct (pow2_S (S n)).
- assert (nz:2 ^ (S n) <> 0) by lia.
- apply C_telescope_pow_0.
- - unfold w.
- rewrite H1.
- rewrite nth_root_div.
- apply nth_root_not_1.
- rewrite <- H1.
- intros HH.
- apply (f_equal Z.of_nat) in HH.
-
- autorewrite with of_nat_re in HH.
- + rewrite <- Zdiv.Zplus_mod_idemp_r in HH.
- rewrite <- Zdiv.Zminus_mod_idemp_l in HH.
- rewrite Zdiv.Z_mod_same_full in HH.
- rewrite Zdiv.Zminus_mod_idemp_r in HH.
- rewrite Zdiv.Zplus_mod_idemp_r in HH.
- apply H.
- apply Nat2Z.inj.
- autorewrite with of_nat_re.
- apply (f_equal (fun x => (x + (2 * Z.of_nat k + 1)) mod 2 ^ Z.of_nat (S n)))%Z in HH.
- rewrite Zplus_0_l in HH.
- rewrite <- HH.
- rewrite Zdiv.Zplus_mod_idemp_l.
- f_equal.
- lia.
- + apply Nat.lt_le_incl.
- apply Nat.mod_upper_bound.
- lia.
- - unfold w.
- rewrite H1.
- rewrite nth_root_div.
- rewrite Cpow_nth_root.
- apply nth_root_1.
- rewrite <- H0, <- H1.
- assert (exists (k2 : nat),
- (2 ^ S n - (2 * k + 1) mod 2^ S n = 2*k2 + 1)).
- {
- assert (odd1:Nat.Odd ((2 * k + 1) mod 2 ^ S n)).
- {
- apply mod_odd_even; trivial; red; eauto.
- exists (2 ^ n).
- simpl; lia.
- }
- apply Nat.odd_spec in odd1.
- apply Nat.odd_spec.
- rewrite Nat.odd_sub.
- - rewrite odd1.
- now rewrite Nat.odd_pow by congruence.
- - apply Nat.lt_le_incl.
- apply Nat.mod_upper_bound.
- lia.
- }
- destruct H2.
- rewrite H2.
- replace (2 * j + 1 + (2 *x1 + 1)) with (2 * (j + x1 + 1)) by lia.
- replace (2 ^ n * (2 * (j + x1 + 1))) with
- ((j + x1 + 1) * (2 ^ S n)).
- + apply Nat.mod_mul; lia.
- + simpl; lia.
-Qed.
-
-
-Lemma sum_pow_div (c1 c2 : C) n :
- c1 = c2 ->
- c2 <> 0%R ->
- list_Cplus (map (fun j => Cpow (Cdiv c1 c2) j) (seq 0 n)) = INR (n).
-Proof.
- intros.
- replace (map (fun j => Cpow (Cdiv c1 c2) j) (seq 0 n)) with
- (map (fun (j:nat) => RtoC 1%R) (seq 0 n)).
- - induction n.
- + easy.
- + rewrite seq_S.
- rewrite map_app.
- rewrite list_Cplus_app.
- rewrite IHn.
- rewrite S_INR.
- simpl.
- rewrite Cplus_0_r.
- now rewrite RtoC_plus.
- - apply map_ext.
- intros.
- rewrite H.
- unfold Cdiv.
- rewrite Cinv_r; trivial.
- now rewrite Cpow_1_l.
-Qed.
-
-Lemma odd_nth_root_div_pow_sum_1 j k n :
- j mod (2^(S n)) = k mod (2 ^ (S n)) ->
- let w := Cdiv (nth_root j (2 ^ (S n))) (nth_root k (2 ^ (S n))) in
- list_Cplus (map (fun j => Cpow w j) (seq 0 (2^n))) = INR (2^n).
-Proof.
- intros.
- unfold w.
- destruct (pow2_S (S n)).
- rewrite H0.
- apply sum_pow_div.
- - apply nth_root_mod.
- now rewrite <- H0.
- - apply nth_root_not_0.
-Qed.
-
-Lemma conj_transpose (m : list (list C)) :
- conj_mat (transpose_mat m) = transpose_mat (conj_mat m).
-Proof.
- unfold conj_mat, transpose_mat.
- rewrite map_map, map_length.
- apply map_ext; intros.
- do 2 rewrite map_map.
- apply map_ext; intros.
- replace (RtoC 0%R) with (Cconj 0%R) at 2.
- - now rewrite map_nth.
- - unfold Cconj, RtoC, fst, snd.
- f_equal; lra.
-Qed.
-
-Lemma V_conj_transpose {n} (m : Matrix C n n) :
- V_conj_mat (transpose m) = transpose (V_conj_mat m).
-Proof.
- unfold V_conj_mat, transpose.
- easy.
-Qed.
-
-(*
-Lemma transpose_involutive (m : list (list C)) :
- transpose_mat (transpose_mat m) = m.
-Proof.
- unfold transpose_mat.
- rewrite map_length.
- rewrite seq_length.
- apply nth_error_eqs; intros.
- rewrite nth_error_map.
- destruct (lt_dec i (length m)).
- - rewrite seq_nth_error; trivial.
- unfold option_map.
- rewrite nth_error_nth' with (d := nil); trivial.
- f_equal.
- rewrite map_map.
-
- admit.
- - unfold option_map.
- assert (length m <= i) by lia.
- apply nth_error_None in H.
- rewrite H.
- assert (i >= length (seq 0 (length m))).
- {
- rewrite seq_length; lia.
- }
- apply nth_error_None in H0.
- now rewrite H0.
- Admitted.
-*)
-
-(*
-Lemma decode_mat_encode_mat_on_diag (n : nat):
- let pmat := (peval_mat (odd_nth_roots (S n))) in
- let prod := mat_mat_mult pmat (conj_mat (transpose_mat pmat)) in
- forall n,
-(* n < length prod -> *)
- nth n (nth n prod nil) 0%R = RtoC (2^S n)%R.
-Proof.
- intros.
- unfold prod, mat_mat_mult.
- rewrite conj_transpose.
- rewrite transpose_involutive.
- unfold mat_vec_mult.
- unfold pmat, peval_mat, conj_mat.
- do 2 rewrite map_map.
-
-Admitted.
-*)
-
-Lemma V_transpose_involutive {T} {n1 n2} (m : Matrix T n1 n2) :
- transpose (transpose m) = m.
-Proof.
- now unfold transpose.
-Qed.
-
-Lemma nth_root_conj_alt j n :
- Cconj (nth_root j (S n)) = nth_root (S n - j mod (S n)) (S n).
-Proof.
- rewrite nth_root_conj.
- now rewrite nth_root_inv.
-Qed.
-
-Lemma vector_sum_list_Cplus {n} (v : Vector C n) :
- vector_sum v = list_Cplus (vector_to_list v).
-Proof.
- unfold vector_sum, vector_to_list.
- induction n.
- - unfold vector_fold_right.
- unfold vector_fold_right_dep.
- unfold vector_fold_right_bounded_dep.
- now simpl.
- - rewrite vector_fold_right_Sn.
- rewrite vector_fold_right_Sn.
- simpl.
- f_equal.
- now rewrite IHn.
- Qed.
-
-(*
-Lemma list_Cplus_vector_sum (l : list C) :
- list_Cplus l = vector_sum (list_to_vector l).
-Proof.
- unfold list_to_vector.
- unfold vector_sum.
- induction l.
- - unfold vector_fold_right, vector_fold_right_dep, vector_fold_right_bounded_dep.
- now simpl.
- - simpl.
- rewrite IHl.
- unfold list_fold_right_dep.
- unfold list_fold_right1_dep.
- Search vcons.
-Admitted.
-*)
-
-Lemma V_telescope_pow_0 (c : C) (n : nat) :
- c <> R1 ->
- Cpow c (S n) = 1%R ->
- vector_sum (fun (j : { j | j < S n}) => Cpow c (proj1_sig j)) = 0%R.
-Proof.
- intros.
- generalize (C_telescope_pow_0 c n H H0); intros.
- rewrite <- H1.
- rewrite vector_sum_list_Cplus.
- apply list_Cplus_perm_proper.
- unfold vector_to_list.
- assert (vector_fold_right cons nil
- (fun j : {j : nat | j < S n} => (c ^ proj1_sig j)%C) =
- (rev (map (fun j : nat => (c ^ j)%C) (seq 0 (S n))))).
- {
- clear H0 H1.
- induction n.
- - unfold vector_fold_right.
- unfold vector_fold_right_dep.
- unfold vector_fold_right_bounded_dep.
- now simpl.
- - rewrite vector_fold_right_Sn.
- rewrite seq_S.
- rewrite map_app.
- rewrite rev_app_distr.
- rewrite <- IHn.
- simpl.
- unfold vlast, proj1_sig.
- f_equal.
- }
- rewrite H2.
- symmetry.
- apply Permutation_rev.
-Qed.
-
-Lemma nat_mod_mul a b c :
- (a * c) <> 0 ->
- b mod c = 0 ->
- (a * b) mod (a * c) = 0.
-Proof.
- intros.
- generalize (Nat.div_mod_eq b c); intros.
- rewrite H0 in H1.
- replace (c * (b / c) + 0) with (c * (b / c)) in H1 by lia.
- rewrite H1.
- replace (a * (c * (b / c))) with ((b/c) * (a * c)) by lia.
- now rewrite Nat.mod_mul.
-Qed.
-
-Lemma V_decode_mat_encode_mat_off_diag (n : nat):
- let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in
- let prod := V_mat_mat_mult pmat (V_conj_mat (transpose pmat)) in
- forall i j,
- proj1_sig i <> proj1_sig j ->
- prod i j = 0%R.
-Proof.
- intros.
- unfold prod.
- rewrite V_conj_transpose.
- unfold V_mat_mat_mult, V_conj_mat.
- rewrite V_transpose_involutive.
- unfold pmat.
- unfold V_inner_prod.
- unfold V_peval_mat.
- unfold V_odd_nth_roots.
- destruct i.
- destruct j.
- unfold proj1_sig in *.
- destruct (pow2_S (S n)).
- rewrite H0.
- destruct (pow2_S (S (S n))).
- rewrite H1.
- generalize (V_telescope_pow_0 (Cmult (nth_root (2 * x + 1) (S x2))
- (Cconj (nth_root (2 * x0 + 1) (S x2))))
- x1
- ); intros.
- rewrite <- H2.
- - f_equal.
- apply vec_eq_eq; intros x3.
- simpl.
- now rewrite Cpow_mult_l, Cpow_conj.
- - rewrite nth_root_conj_alt.
- rewrite nth_root_mul.
- apply nth_root_not_1.
- rewrite <- H1.
- assert (2 * x + 1 < 2 ^ (S (S n))).
- {
- simpl.
- simpl in l.
- lia.
- }
- assert (2 * x0 + 1 < 2 ^ (S (S n))).
- {
- simpl.
- simpl in l0.
- lia.
- }
- intros HH.
- apply (f_equal Z.of_nat) in HH.
- autorewrite with of_nat_re in HH; [|rewrite Nat.mod_small; lia].
- rewrite <- Zdiv.Zplus_mod_idemp_r in HH.
- rewrite <- Zdiv.Zminus_mod_idemp_l in HH.
- rewrite Zdiv.Z_mod_same_full in HH.
- rewrite Zdiv.Zminus_mod_idemp_r in HH.
- rewrite Zdiv.Zplus_mod_idemp_r in HH.
- apply H.
- apply Nat2Z.inj.
- autorewrite with of_nat_re.
- replace (2 * Z.of_nat x + 1 + (0 - (2 * Z.of_nat x0 + 1)))%Z with
- (2 * Z.of_nat x - 2 * Z.of_nat x0)%Z in HH by lia.
- apply (f_equal (fun x => (x + (2 * Z.of_nat x0)) mod 2 ^ Z.of_nat (S (S n))))%Z in HH.
- rewrite Zplus_0_l in HH.
- rewrite Zdiv.Zplus_mod_idemp_l in HH.
- replace (2 * Z.of_nat x - 2 * Z.of_nat x0 + 2 * Z.of_nat x0)%Z with
- (2 * Z.of_nat x)%Z in HH by lia.
- rewrite Z.mod_small in HH.
- + rewrite Z.mod_small in HH; try lia.
- split; try lia.
- apply inj_lt in l0.
- rewrite Nat2Z.inj_pow in l0.
- replace (2 ^ Z.of_nat (S (S n)))%Z with
- (2 * 2 ^ Z.of_nat (S n))%Z; try lia.
- rewrite (Nat2Z.inj_succ (S n)).
- rewrite Z.pow_succ_r; lia.
- + split; try lia.
- apply inj_lt in l.
- rewrite Nat2Z.inj_pow in l.
- replace (2 ^ Z.of_nat (S (S n)))%Z with
- (2 * 2 ^ Z.of_nat (S n))%Z; try lia.
- rewrite (Nat2Z.inj_succ (S n)).
- rewrite Z.pow_succ_r; lia.
- - rewrite nth_root_conj_alt.
- rewrite nth_root_mul.
- generalize nth_root_1; intros.
- rewrite Cpow_nth_root.
- apply nth_root_1.
- rewrite <- H0, <- H1.
- replace (2 ^ (S (S n))) with ((2 ^ S n)*2) by (simpl;lia).
- rewrite nat_mod_mul; try lia.
- rewrite Nat.add_mod; try lia.
- assert ((2 * x + 1) mod 2 = 1).
- {
- rewrite Nat.add_mod; try lia.
- replace (2 * x) with (x * 2) by lia.
- rewrite Nat.mod_mul; try lia.
- now simpl.
- }
- rewrite H4.
- assert (exists (k : nat),
- (2 ^ S n * 2 - (2 * x0 + 1) mod (2^ S n * 2) = 2 * k + 1)).
- {
- assert (odd1:Nat.Odd ((2 * x0 + 1) mod (2 ^ S n * 2))).
- {
- apply mod_odd_even; trivial; red; eauto; try lia.
- exists (2 ^ S n).
- simpl; lia.
- }
- apply Nat.odd_spec in odd1.
- apply Nat.odd_spec.
- rewrite Nat.odd_sub.
- - rewrite odd1.
- rewrite Nat.odd_mul.
- now rewrite Nat.odd_pow by congruence.
- - apply Nat.lt_le_incl.
- apply Nat.mod_upper_bound.
- lia.
- }
- destruct H5.
- rewrite H5.
- replace (2 * x3 + 1) with (1 + x3 * 2) by lia.
- replace ((1 + x3 * 2) mod 2) with 1; [now simpl | ].
- rewrite Nat.add_mod; try lia.
- rewrite Nat.mod_mul; try lia.
- simpl; lia.
- Qed.
-
-Lemma root_conj_power_inv i j n :
- Cmult (Cpow (nth_root i (S n)) j)
- (Cconj (Cpow (nth_root i (S n)) j)) = 1%R.
-Proof.
- Search nth_root.
- rewrite Cpow_nth_root.
- now rewrite mult_conj_root.
-Qed.
-
-Lemma V_decode_mat_encode_mat_on_diag (n : nat):
- let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in
- let prod := V_mat_mat_mult pmat (V_conj_mat (transpose pmat)) in
- forall n0,
- prod n0 n0 = RtoC (2^S n)%R.
-Proof.
- intros.
- unfold prod.
- rewrite V_conj_transpose.
- unfold V_mat_mat_mult, V_conj_mat.
- rewrite V_transpose_involutive.
- unfold pmat.
- unfold V_inner_prod.
- replace (fun n' : {n' : nat | n' < 2 ^ S n} =>
- (V_peval_mat (V_odd_nth_roots (S n)) n0 n' *
- Cconj (V_peval_mat (V_odd_nth_roots (S n)) n0 n'))%C) with
- (ConstVector (2 ^ S n) (RtoC 1%R)).
- - rewrite vector_sum_const.
- rewrite Cmult_1_r.
- f_equal.
- now rewrite pow_INR.
- - apply vec_eq_eq.
- intros ?.
- unfold ConstVector.
- unfold V_peval_mat.
- unfold V_odd_nth_roots.
- destruct (pow2_S (S (S n))).
- rewrite H.
- now rewrite root_conj_power_inv.
- Qed.
-
-(*
-Lemma decode_mat_encode_mat_off_diag (n : nat):
- let pmat := (peval_mat (odd_nth_roots (S n))) in
- let prod := mat_mat_mult pmat (conj_mat (transpose_mat pmat)) in
- forall i j,
-(* i < length prod ->
- j < length prod ->
-*)
- i <> j ->
- nth i (nth j prod nil) 0%R = 0%R.
-Proof.
- intros.
- unfold prod, mat_mat_mult.
- rewrite conj_transpose.
- rewrite transpose_involutive.
- unfold mat_vec_mult.
- unfold pmat, peval_mat, conj_mat.
- do 2 rewrite map_map.
-
-
-Admitted.
-
-
-Lemma decode_mat_encode_mat (cl : list C) (n : nat):
- length cl = length (odd_nth_roots (S n)) ->
- let pmat := (peval_mat (odd_nth_roots (S n))) in
- mat_vec_mult (mat_mat_mult pmat (conj_mat (transpose_mat pmat))) cl =
- map (fun c => Cmult c (2^(S n))%R) cl.
-Proof.
- intros.
- unfold mat_vec_mult, mat_mat_mult.
- rewrite conj_transpose.
- rewrite transpose_involutive.
- rewrite map_map.
- unfold mat_vec_mult.
- unfold pmat.
- Admitted.
-*)
-
-Lemma vector_sum_all_but_1_0 n (i : {i : nat | i < n}) c :
- vector_sum (fun (n' : {n' : nat | n' < n}) => if eq_nat_decide (proj1_sig i) (proj1_sig n') then c else 0%R) = c.
-Proof.
- unfold vector_sum.
- induction n.
- - unfold vector_fold_right.
- unfold vector_fold_right_dep.
- unfold vector_fold_right_bounded_dep.
- simpl.
- destruct i.
- lia.
- - rewrite vector_fold_right_Sn.
- destruct (Compare_dec.lt_dec (proj1_sig i) n).
- + unfold vlast, proj1_sig.
- destruct i.
- unfold proj1_sig in l.
- match_destr.
- * apply eq_nat_eq in e.
- lia.
- * specialize (IHn (exist _ x l)).
- rewrite Cplus_0_l.
- unfold vdrop_last.
- rewrite <- IHn at 1.
- apply vector_fold_right_ext.
- unfold vec_eq.
- intros.
- now destruct i.
- + unfold vlast, proj1_sig.
- destruct i.
- unfold proj1_sig in n0.
- assert (x = n) by lia.
- rewrite H.
- unfold vdrop_last.
- match_destr.
- * replace (vector_fold_right Cplus 0%R
- (fun H0 : {n' : nat | n' < n} =>
- let (x0, _) := H0 in if eq_nat_decide n x0 then c else 0%R))%C
- with (RtoC (0%R)).
- -- now rewrite Cplus_0_r.
- -- assert (RtoC (0%R) = vector_sum (ConstVector n (RtoC (0%R)))).
- {
- rewrite vector_sum_const.
- now rewrite Cmult_0_r.
- }
- rewrite H0 at 1.
- unfold vector_sum, ConstVector.
- apply vector_fold_right_ext.
- unfold vec_eq.
- intros.
- destruct i.
- match_destr.
- apply eq_nat_eq in e0; lia.
- * now generalize (eq_nat_refl n); intros.
- Qed.
-
-Lemma index_eq {n} (i j : {n' : nat | n' < n}) :
- proj1_sig i = proj1_sig j ->
- i = j.
-Proof.
- intros.
- destruct i.
- destruct j.
- unfold proj1_sig in H.
- subst.
- apply index_pf_irrel.
-Qed.
-
-Lemma V_decode_mat_encode_mat_assoc_l (n : nat) (cl : Vector C (2^(S n))) :
- let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in
- let prod := V_mat_mat_mult pmat (V_conj_mat (transpose pmat)) in
- V_mat_vec_mult prod cl = Vscale (RtoC (2^S n)%R) cl.
-Proof.
- generalize (V_decode_mat_encode_mat_on_diag n); intros.
- generalize (V_decode_mat_encode_mat_off_diag n); intros.
- simpl in H; simpl in H0.
- unfold prod, pmat.
- apply vec_eq_eq; intros x.
- unfold V_mat_vec_mult, Vscale.
- unfold V_inner_prod.
- generalize (vector_sum_all_but_1_0 (2 ^ S n) x ((RtoC (2 ^ S n)%R) * cl x)); intros.
- rewrite <- H1.
- f_equal.
- apply vec_eq_eq; intros ?.
- clear H1.
- destruct (eq_nat_decide (proj1_sig x) (proj1_sig i)).
- - apply eq_nat_eq in e.
- specialize (H x).
- generalize (index_eq x i e); intros.
- rewrite <- H1.
- apply (f_equal (fun z => (z * cl x)%C)) in H.
- replace (2 ^ S n)%R with (2 * 2^n)%R by now simpl.
- rewrite <- H.
- f_equal.
- - specialize (H0 x i).
- cut_to H0.
- + apply (f_equal (fun z => (z * cl i)%C)) in H0.
- rewrite Cmult_0_l in H0.
- rewrite <- H0.
- f_equal.
- + unfold not; intros.
- now apply eq_eq_nat in H1.
-Qed.
-
-Lemma V_decode_mat_encode_mat (n : nat) (cl : Vector C (2^(S n))) :
- let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in
- let encmat := (V_conj_mat (transpose pmat)) in
- V_mat_vec_mult pmat (V_mat_vec_mult encmat cl) = Vscale (RtoC (2^S n)%R) cl.
-Proof.
- unfold Vector in cl.
- generalize (V_decode_mat_encode_mat_assoc_l n cl); intros.
- rewrite <- H.
- now rewrite V_mmv_mult_assoc.
-Qed.
-
-Definition diag_matrix {n} (v : Vector C n) : Matrix C n n :=
- (fun i j => if eq_nat_decide (proj1_sig j) (proj1_sig i) then (v i) else 0%R).
-
-(* shows evaluation can be done by modified FFT of half size*)
-Lemma V_peval_mat_prod (n : nat) :
- V_peval_mat (V_odd_nth_roots (S n)) =
- V_mat_mat_mult (V_peval_mat (V_even_nth_roots (S n)))
- (diag_matrix (V_nth_roots_half (S n))).
-Proof.
- apply vec_eq_eq; intros ?.
- apply vec_eq_eq; intros ?.
- unfold V_peval_mat, diag_matrix, V_mat_mat_mult.
- unfold V_inner_prod, transpose.
- generalize (vector_sum_all_but_1_0 (2 ^ S n) i0 (V_odd_nth_roots (S n) i ^ proj1_sig i0)%C); intros.
- rewrite <- H.
- f_equal.
- apply vec_eq_eq; intros ?.
- match_destr.
- - unfold V_odd_nth_roots, V_even_nth_roots, V_nth_roots_half.
- destruct (pow2_S (S (S n))).
- rewrite H0.
- do 2 rewrite Cpow_nth_root.
- rewrite nth_root_mul.
- f_equal.
- replace (proj1_sig i1) with (proj1_sig i0); try lia.
- now apply eq_nat_eq.
- - now rewrite Cmult_0_r.
-Qed.
-
-(* shows enconding can be done by modified IFFT of half size*)
-Lemma V_encode_mat_prod (n : nat) :
- let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in
- let encmat := (V_conj_mat (transpose pmat)) in
- encmat =
- V_mat_mat_mult
- (diag_matrix (vmap' Cconj (V_nth_roots_half (S n))))
- (V_peval_mat (vmap' Cconj (V_even_nth_roots (S n)))).
-Proof.
- apply vec_eq_eq; intros ?.
- apply vec_eq_eq; intros ?.
- unfold V_peval_mat, diag_matrix, V_mat_mat_mult, V_conj_mat.
- unfold V_inner_prod, transpose.
- generalize (vector_sum_all_but_1_0 (2 ^ S n) i (Cconj (V_odd_nth_roots (S n) i0 ^ proj1_sig i))); intros.
- rewrite <- H.
- f_equal.
- apply vec_eq_eq; intros ?.
- match_destr.
- - assert (eq_nat (proj1_sig i1) (proj1_sig i)).
- {
- apply eq_nat_eq in e.
- apply eq_eq_nat; lia.
- }
- match_destr; try congruence.
- unfold vmap', V_odd_nth_roots, V_even_nth_roots, V_nth_roots_half.
- destruct (pow2_S (S (S n))).
- rewrite H1.
- rewrite <- Cpow_conj.
- rewrite <- Cmult_conj.
- f_equal.
- do 2 rewrite Cpow_nth_root.
- rewrite nth_root_mul.
- f_equal.
- replace (proj1_sig i1) with (proj1_sig i); try lia.
- now apply eq_nat_eq.
- - assert (~ eq_nat (proj1_sig i1) (proj1_sig i)).
- {
- unfold not; intros.
- apply eq_nat_eq in H0.
- rewrite H0 in n0.
- now generalize (eq_nat_refl (proj1_sig i)).
- }
- match_destr; try congruence.
- now rewrite Cmult_0_l.
-Qed.
-
-Lemma nth_root_even_half (i n : nat) :
- nth_root (2 * i) (2 * (S n)) = nth_root i (S n).
-Proof.
- unfold nth_root.
- do 2 rewrite mult_INR.
- generalize (S_INR_not_0 n); intros.
- generalize (S_INR_not_0 1); intros.
- do 2 f_equal; field; split; lra.
-Qed.
-
-Program Definition index_reflect {n} (n' : {n' : nat | n' < n}) : {n' : nat | n' < n} :=
- (exist _ (n - 1 - (proj1_sig n')) _).
-Next Obligation.
- lia.
-Qed.
-
-Lemma index_reflect_involutive {n} (i : {i' : nat | i' < n}) :
- index_reflect (index_reflect i) = i.
-Proof.
- unfold index_reflect.
- apply index_eq.
- destruct i.
- unfold proj1_sig; lia.
-Qed.
-
-Definition vector_rev {n} {T} (v : Vector T n) :=
- fun i => v (index_reflect i).
-
-Definition vector_rev_conj {n} (v : Vector C n) :=
- forall i,
- v i = Cconj (v (index_reflect i)).
-
-Lemma vector_rev_conj_plus {n} (v1 v2 : Vector C n) :
- vector_rev_conj v1 ->
- vector_rev_conj v2 ->
- vector_rev_conj (vmap' (fun '(a,b) => Cplus a b) (vector_zip v1 v2)).
-Proof.
- unfold vector_rev_conj; intros.
- unfold vector_zip, vmap'.
- rewrite H, H0.
- now rewrite Complex.Cplus_conj.
-Qed.
-
-Lemma vector_rev_conj_mult {n} (v1 v2 : Vector C n) :
- vector_rev_conj v1 ->
- vector_rev_conj v2 ->
- vector_rev_conj (vmap' (fun '(a,b) => Cmult a b) (vector_zip v1 v2)).
-Proof.
- unfold vector_rev_conj; intros.
- unfold vector_zip, vmap'.
- rewrite H, H0.
- now rewrite Cmult_conj.
-Qed.
-
-Lemma vector_rev_conj_scale {n} (r : R) (v : Vector C n) :
- vector_rev_conj v ->
- vector_rev_conj (Vscale (RtoC r) v).
-Proof.
- unfold vector_rev_conj; intros.
- unfold Vscale.
- rewrite H.
- rewrite Cmult_conj.
- f_equal.
- unfold Cconj, fst, snd, RtoC.
- f_equal; lra.
-Qed.
-
-Lemma vector_rev_conj_const_R n (r : R) :
- vector_rev_conj (ConstVector n (RtoC r)).
-Proof.
- unfold vector_rev_conj, ConstVector, RtoC, Cconj, fst, snd; intros.
- f_equal; lra.
-Qed.
-
-Lemma vector_rev_conj_conj {n} (v : Vector C n) :
- vector_rev_conj v ->
- vector_rev_conj (vmap' Cconj v).
-Proof.
- unfold vector_rev_conj, vmap'; intros.
- now rewrite H.
-Qed.
-
-Lemma vector_rev_conj_Cpow {n} i (v : Vector C n) :
- vector_rev_conj v ->
- vector_rev_conj (vmap' (fun c => Cpow c i) v).
-Proof.
- unfold vector_rev_conj, vmap'; intros.
- rewrite H.
- now rewrite Cpow_conj.
-Qed.
-
-Lemma map_Cconj_vector_to_list {n} (v : Vector C n) :
- map Cconj (vector_to_list v) = vector_to_list (vmap' Cconj v).
-Proof.
- unfold vector_to_list.
- induction n.
- - unfold vector_fold_right, vector_fold_right_dep, vector_fold_right_bounded_dep.
- now simpl.
- - do 2 rewrite vector_fold_right_Sn.
- rewrite map_cons.
- rewrite IHn.
- f_equal.
- Qed.
-
-Lemma rev_vector_to_list {n} {T} (v : Vector T n) :
- rev (vector_to_list v) = vector_to_list (vector_rev v).
-Proof.
- unfold vector_rev.
- unfold vector_to_list.
- induction n.
- - unfold vector_fold_right, vector_fold_right_dep, vector_fold_right_bounded_dep.
- now simpl.
- - do 2 rewrite vector_fold_right_Sn.
-
-Admitted.
-
-Lemma vector_rev_conj_sum {n} (v : Vector C n) :
- vector_rev_conj v ->
- Im (vector_sum v) = 0%R.
-Proof.
- rewrite vector_sum_list_Cplus.
- intros.
- apply list_Cplus_conj_rev.
- rewrite map_Cconj_vector_to_list.
- rewrite rev_vector_to_list.
- f_equal.
- apply vec_eq_eq; intros ?.
- assert (vector_rev_conj (vmap' Cconj v)).
- {
- now apply vector_rev_conj_conj.
- }
- rewrite H0.
- unfold vmap'.
- now rewrite Cconj_conj.
-Qed.
-
-Lemma vector_rev_conj_inner {n} (v1 v2 : Vector C n) :
- vector_rev_conj v1 ->
- vector_rev_conj v2 ->
- Im (V_inner_prod v1 v2) = 0%R.
-Proof.
- intros.
- apply vector_rev_conj_sum.
- now apply vector_rev_conj_mult.
-Qed.
-
-Lemma vector_cplus_comm {n} (v1 v2 : Vector C n) :
- (vmap' (fun '(a,b) => Cplus a b) (vector_zip v1 v2)) =
- (vmap' (fun '(a,b) => Cplus a b) (vector_zip v2 v1)).
-Proof.
- unfold vmap', vector_zip.
- apply vec_eq_eq; intros ?.
- apply Cplus_comm.
-Qed.
-
-Lemma vector_cmult_comm {n} (v1 v2 : Vector C n) :
- (vmap' (fun '(a,b) => Cmult a b) (vector_zip v1 v2)) =
- (vmap' (fun '(a,b) => Cmult a b) (vector_zip v2 v1)).
-Proof.
- unfold vmap', vector_zip.
- apply vec_eq_eq; intros ?.
- apply Cmult_comm.
-Qed.
-
-Lemma vector_cplus_assoc {n} (v1 v2 v3 : Vector C n) :
- vmap' (fun '(a,b) => Cplus a b) (vector_zip v1 (vmap' (fun '(a,b) => Cplus a b) (vector_zip v2 v3))) =
- vmap' (fun '(a,b) => Cplus a b) (vector_zip (vmap' (fun '(a,b) => Cplus a b) (vector_zip v1 v2)) v3).
-Proof.
- unfold vmap', vector_zip.
- apply vec_eq_eq; intros ?.
- apply Cplus_assoc.
-Qed.
-
-Lemma vector_cmult_assoc {n} (v1 v2 v3 : Vector C n) :
- vmap' (fun '(a,b) => Cmult a b) (vector_zip v1 (vmap' (fun '(a,b) => Cmult a b) (vector_zip v2 v3))) =
- vmap' (fun '(a,b) => Cmult a b) (vector_zip (vmap' (fun '(a,b) => Cmult a b) (vector_zip v1 v2)) v3).
-Proof.
- unfold vmap', vector_zip.
- apply vec_eq_eq; intros ?.
- apply Cmult_assoc.
-Qed.
-
-Lemma vector_rev_conj_odd_nth_roots (n : nat) :
- vector_rev_conj (V_odd_nth_roots (S n)).
-Proof.
- unfold vector_rev_conj, V_odd_nth_roots.
- intros.
- destruct i.
- unfold index_reflect, proj1_sig.
- destruct (pow2_S (S (S n))).
- rewrite H.
- rewrite nth_root_conj_alt.
- f_equal.
- rewrite <- H.
- replace (2^S (S n)) with (2 * 2^S n) by (simpl; lia).
- rewrite Nat.mod_small; lia.
-Qed.
-
-Lemma V_mat_encode_real (n : nat) (cl : Vector C (2^(S n))) :
- let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in
- let encmat := (V_conj_mat (transpose pmat)) in
- vector_rev_conj cl ->
- forall i,
- Im ((V_mat_vec_mult encmat cl) i) = 0%R.
-Proof.
- unfold V_mat_vec_mult, transpose, V_peval_mat, V_conj_mat.
- intros.
- apply vector_rev_conj_inner; trivial.
- apply (vector_rev_conj_conj (vmap' (fun c => Cpow c (proj1_sig i)) (V_odd_nth_roots (S n)))).
- apply vector_rev_conj_Cpow, vector_rev_conj_odd_nth_roots.
-Qed.
-
-Lemma V_mat_encode_real_alt (n : nat) (cl : Vector C (2^(S n))) :
- let pmat := (V_peval_mat (V_odd_nth_roots (S n))) in
- let encmat := (V_conj_mat (transpose pmat)) in
- vector_rev_conj cl ->
- V_mat_vec_mult encmat cl = vmap' RtoC (vmap' Re (V_mat_vec_mult encmat cl)).
-Proof.
- intros.
- apply vec_eq_eq; intros ?.
- apply Re_Im.
- now apply V_mat_encode_real.
-Qed.
-
-
diff --git a/coq/FHE/zp_prim_root.v b/coq/FHE/zp_prim_root.v
deleted file mode 100644
index a8c81da0..00000000
--- a/coq/FHE/zp_prim_root.v
+++ /dev/null
@@ -1,3070 +0,0 @@
-Require Import Lia.
-From mathcomp Require Import all_ssreflect zmodp poly ssralg cyclic fingroup finalg ring seq bigop.
-Require Import encode.
-
-Set Implicit Arguments.
-Unset Strict Implicit.
-Unset Printing Implicit Defensive.
-Set Bullet Behavior "Strict Subproofs".
-
-(* next 3 lemmas are copied from mathcomp_extra rsa *)
-
-(* This should be part of the standard library *)
-
-Lemma prime_modn_expSn p n : prime p -> n.+1 ^ p = (n ^ p).+1 %[mod p].
-Proof.
-case: p => // p pP.
-rewrite -[(_ ^ _).+1]addn0 (expnDn 1) big_ord_recr big_ord_recl /=.
-rewrite subnn binn exp1n !mul1n addnAC -modnDmr; congr ((_ + _) %% _).
-apply/eqP/dvdn_sum => -[i ?] _; exact/dvdn_mulr/prime_dvd_bin.
-Qed.
-
-(* This should be part of the standard library *)
-
-Lemma fermat_little a p : prime p -> a ^ p = a %[mod p].
-Proof.
-move=> pP.
-elim: a => [|a IH]; first by rewrite exp0n // prime_gt0.
-by rewrite prime_modn_expSn // -addn1 -modnDml IH modnDml addn1.
-Qed.
-
-(* This should be part of the standard library *)
-
-Lemma fermat_little_pred a p : prime p -> ~(p %| a) -> a ^ p.-1 = 1 %[mod p].
-Proof.
-move=> Pp pNDa.
-have a_gt0 : 0 < a by case: a pNDa.
-have aCp : coprime a p by rewrite coprime_sym prime_coprime //; apply/negP.
-have aE : (egcdn a p).1 * a = 1 %[mod p].
- by case: egcdnP => //= km kn -> _; rewrite (eqP aCp) modnMDl.
-rewrite -[_^ _]muln1 -modnMmr -[in LHS]aE // modnMmr.
-rewrite mulnC -mulnA -expnS prednK ?prime_gt0 //.
-by rewrite -modnMmr fermat_little // modnMmr aE.
-Qed.
-
-Import ssralg.GRing.
-
-Section cyclic.
-Local Open Scope ring_scope.
-
-Variable p : nat.
-
-Lemma prime_pbig2 (q : nat) :
- prime q ->
- 0 < (q.-1)%N.
-Proof.
- move=> p_prime.
- move: (p_prime) => /(prime_gt0).
- destruct q; [| destruct q].
- - by rewrite ltnn.
- - by inversion p_prime.
- - by rewrite ltn0Sn.
-Qed.
-
- Lemma Fp_exp_expn (x : 'F_p) (n : nat):
- prime p ->
- nat_of_ord (x ^+ n)%R = x ^ n %% p.
- Proof.
- move=> p_prime.
- have peqq := (Fp_cast p_prime).
- induction n.
- - by rewrite /= {1}peqq expn0.
- - rewrite expnS exprS /mul /= IHn -modnMm.
- rewrite {2 4 5}peqq.
- by rewrite -(modnMm x _) modn_mod.
- Qed.
-
- Lemma zp_prim_root_max :
- prime p ->
- { w : 'F_p | (p.-1).-primitive_root w}.
- Proof.
- move=> p_prime.
- have pbig := (prime_pbig2 p_prime).
- have/(nth_find 0)-HH: has (p.-1).-primitive_root [seq x <- enum 'F_p | x != Zp0].
- {
- apply (@has_prim_root _ _ [seq x <- enum 'F_p | x != Zp0]) => //=.
- - rewrite all_filter.
- apply/allP => /= x xin.
- apply/implyP=> xn0.
- rewrite unity_rootE.
- have/(fermat_little_pred p_prime)-eqq: ~ p %| x.
- {
- have xltp: x < p.
- {
- move: (ltn_ord x).
- by rewrite {2}(Fp_cast p_prime).
- }
- have xpos: (0 < x) by by rewrite lt0n.
- move/(dvdn_leq xpos); lia.
- }
- apply /eqP.
- apply val_inj => /=.
- by rewrite {2}(Fp_cast p_prime) -eqq Fp_exp_expn.
- - apply filter_uniq.
- apply enum_uniq.
- - rewrite -rem_filter; [| apply enum_uniq].
- rewrite size_rem.
- + by rewrite -cardE (card_Fp p_prime).
- + by rewrite enumT Finite.EnumDef.enumDef /= mem_ord_enum.
- }
- by exists ([seq x <- enum 'F_p | x != 0]`_(find (p.-1).-primitive_root [seq x <- enum 'F_p | x != 0])).
- Qed.
-
- Lemma zp_prim_root (n : nat) :
- n > 0 ->
- prime p ->
- n %| p.-1 ->
- { w : 'F_p | n.-primitive_root w}.
- Proof.
- intros npos prim div.
- destruct (zp_prim_root_max prim).
- generalize (dvdn_prim_root i div); intros.
- by exists (exp x (p.-1 %/ n)).
- Qed.
-
-Lemma inZp_add j k n :
- inZp (j + k) = inZp j + inZp k :> 'Z_n.
-Proof.
- apply: val_inj => /=.
- by rewrite modnDm.
-Qed.
-
-Lemma inZp_mul j k n :
- inZp (j * k) = inZp j * inZp k :> 'Z_n.
-Proof.
- apply: val_inj => /=.
- by rewrite modnMm.
-Qed.
-
-Lemma inZp_exp j k n :
- inZp (j ^ k) = inZp j ^+ k :> 'Z_n.
-Proof.
- induction k.
- - rewrite expn0 expr0.
- by apply: val_inj => /=.
- - rewrite exprS expnS -IHk.
- apply inZp_mul.
-Qed.
-
-End cyclic.
-
-Require Import ssrbool.
-
-
-Section chinese.
-
- (* pairs represent (residue, modulus) *)
- Fixpoint chinese_list (l : seq (nat * nat)) : nat :=
- match l with
- | nil => 0
- | a :: nil => a.1
- | a :: l' =>
- chinese (a.2) (\prod_(i <- (map snd l')) i)
- (a.1) (chinese_list l')
- end.
-
- Lemma all_coprime_prod (a : nat) (l : seq nat) :
- all (coprime a) l ->
- coprime a (\prod_(i <- l) i).
- Proof.
- intros.
- rewrite big_seq.
- apply big_rec.
- - apply coprimen1.
- - intros.
- move/allP/(_ _ H0): H.
- by rewrite coprimeMr H1 => ->.
- Qed.
-
- Lemma chinese_remainder_list_cons_l (a : nat * nat) (l : list (nat * nat)) :
- all (coprime a.2) (map snd l) ->
- chinese_list (a::l) == a.1 %[mod a.2].
- Proof.
- induction l=> //= HH.
- rewrite big_cons.
- destruct l; trivial.
- - rewrite chinese_modl //.
- rewrite big_nil muln1.
- by move/andP: HH => [-> _].
- - rewrite chinese_modl //.
- move/andP: HH => [HH1 HH2/=].
- by rewrite coprimeMr HH1 /= all_coprime_prod.
- Qed.
-
- Lemma pairwise_coprime_cons a l :
- pairwise coprime (a :: l) ->
- all (coprime a) l.
- Proof.
- simpl.
- move /andP.
- tauto.
- Qed.
-
- Lemma chinese_reminder_list_cons_r (a : nat * nat) (l : list (nat * nat)) :
- pairwise coprime (map snd (a::l)) ->
- let m := \prod_(i <- map snd l) i in
- chinese_list (a::l) == chinese_list l %[mod m].
- Proof.
- intros.
- simpl.
- destruct l; trivial.
- - rewrite big_nil /= modn1 //.
- - rewrite chinese_modr //.
- apply all_coprime_prod.
- by apply pairwise_coprime_cons.
- Qed.
-
- Lemma symmetricE {A} (f:A->A->bool) : (ssrbool.symmetric f) <-> (RelationClasses.Symmetric f).
- Proof.
- rewrite /symmetric /RelationClasses.Symmetric.
- split; intros.
- - by rewrite H.
- - case_eq (f x y)=> eqq.
- + by rewrite H.
- + case_eq (f y x)=> eqq2//.
- by rewrite (H _ _ eqq2) in eqq.
- Qed.
-
- Lemma allE {A} (P:pred A) (l:list A) : all P l <-> List.Forall (fun x : A => P x) l.
- Proof.
- elim: l => /=.
- - split=>// _.
- - move=> a l IHl.
- split.
- + move/andP=> [ap pairP].
- constructor; tauto.
- + inversion 1; subst.
- by rewrite H2 IHl.
- Qed.
-
- Lemma pairwiseE {A} (P:rel A) (l:list A) : pairwise P l <-> List.ForallOrdPairs P l.
- Proof.
- elim: l => /=.
- - split=>// _.
- constructor.
- - move=> a l IHl.
- split.
- + move/andP=> [ap pairP].
- constructor;[| tauto].
- by apply allE.
- + inversion 1; subst.
- apply/andP; split.
- * by apply allE.
- * by apply IHl.
- Qed.
-
- Lemma chinese_remainder_list_split :
- forall (l1 l2 l : list (nat * nat)) (p : nat * nat),
- pairwise coprime (map snd l) ->
- l = l1 ++ (p::nil) ++ l2 ->
- chinese_list l == p.1 %[mod p.2].
- Proof.
- induction l1; intros.
- - simpl in H0.
- rewrite H0.
- rewrite H0 in H.
- rewrite chinese_remainder_list_cons_l //.
- simpl in H.
- by move/andP: H => [-> _].
- - pose (l3 := l1 ++ [:: p] ++ l2).
- have ->/=: l = a :: l3 by by rewrite H0.
- case_eq l3.
- + subst l3; destruct l1; simpl; congruence.
- + move=> _ _ <-.
- have pc': pairwise coprime [seq i.2 | i <- l3].
- {
- rewrite H0/= in H.
- by move/andP: H => [_ ->].
- }
- specialize (IHl1 l2 l3 p pc' (Logic.eq_refl _)).
- rewrite <- (eqP IHl1).
- have cp: coprime (a.2) (\prod_(i <- (map snd l3)) i).
-
- {
- rewrite all_coprime_prod //.
- subst l3.
- simpl in H.
- rewrite H0/= in H.
- by move/andP: H => [-> _].
- }
- move/eqP: (chinese_modr cp (a.1) (chinese_list l3)).
- have ->: \prod_(i <- [seq i.2 | i <- l3]) i = \prod_(i <- [seq i.2 | i <- p::(l1++l2)]) i.
- {
- apply perm_big_AC.
- - apply mulnA.
- - apply mulnC.
- - subst l3.
- apply perm_map.
- by rewrite perm_catCA perm_refl.
- }
- simpl.
- rewrite big_cons.
- by apply modn_muln.
- Qed.
-
- Lemma chinese_remainder_list (l : list (nat * nat)) :
- pairwise coprime (map snd l) ->
- forall p,
- p \in l ->
- chinese_list l == p.1 %[mod p.2].
- Proof.
- intros.
- case/splitPr: H0 H=>l1 l2 HH.
- by eapply chinese_remainder_list_split => //=.
- Qed.
-
- Lemma chinese_remainder_list_unique (a b : nat) (l : list nat) :
- pairwise coprime l ->
- (forall p,
- p \in l -> a == b %[mod p]) ->
- a == b %[mod \prod_(i <- l) i].
- Proof.
- induction l; simpl; intros.
- - by rewrite big_nil !modn1.
- - move /andP in H.
- destruct H.
- rewrite big_cons chinese_remainder.
- + apply /andP.
- split.
- * apply H0.
- rewrite in_cons.
- apply /orP.
- by left.
- * apply IHl; trivial.
- intros.
- apply H0.
- rewrite in_cons.
- apply /orP.
- by right.
- + by apply all_coprime_prod.
- Qed.
-
- Definition balanced_chinese_list (l : seq (nat * nat)) :=
- \sum_(p <- l)
- (\prod_(q <- l | q != p) q.2) *
- ((p.1 * (egcdn (\prod_(q <- l | q!= p) q.2) p.2).1) %% p.2).
-
- Lemma egcd_coprime (a b : nat) :
- 0 < a ->
- coprime a b ->
- (egcdn a b).1 * a = 1 %[mod b].
- Proof.
- intros.
- case: egcdnP => //= km kn ->.
- by rewrite (eqP H0) -modnDm modnMl add0n modn_mod.
- Qed.
-
- Lemma egcd_coprime_mult (a b c : nat) :
- 0 < a ->
- coprime a b ->
- c * (egcdn a b).1 * a = c %[mod b].
- Proof.
- intros.
- rewrite -mulnA -modnMmr egcd_coprime //.
- rewrite modnS mod0n dvdn1.
- case: eqVneq => [-> |].
- - by rewrite muln0 !modn1.
- - by rewrite muln1.
- Qed.
-
- Lemma sym_pairwiseP {T : Type} (r : T -> T -> bool) (sym:symmetric r) (x0 : T) (xs : seq T) :
- reflect {in gtn (size xs) &, {homo nth x0 xs : i j / i <> j >-> r i j}} (pairwise r xs).
- Proof.
- induction xs; simpl; first by apply (iffP idP).
- apply: (iffP andP).
- - intros [??]?????.
- destruct x; destruct y; simpl.
- + lia.
- + move/(all_nthP x0): H.
- apply.
- rewrite/in_mem/mem/= in H2.
- lia.
- + rewrite sym.
- move/(all_nthP x0): H.
- apply.
- rewrite/in_mem/mem/= in H1.
- lia.
- + rewrite H0 in IHxs.
- inversion IHxs.
- apply H4; trivial.
- lia.
- - intros ?.
- split.
- + apply/(all_nthP x0)=> i ilt.
- move: (H 0 i.+1) => /=.
- apply; rewrite /in_mem/mem/=; lia.
- + inversion IHxs => //.
- elim H1 => x y xlt ylt xny.
- move: (H x.+1 y.+1).
- unfold in_mem, mem in *; simpl in *.
- apply; lia.
- Qed.
-
- Lemma allrel_sym {A:eqType} f (l1 l2: seq A) :
- symmetric f ->
- allrel f l1 l2 = allrel f l2 l1.
- Proof.
- rewrite allrelC=>sym.
- apply eq_allrel => x y.
- by rewrite sym.
- Qed.
-
- Lemma pairwise_perm_sym {A:eqType} f (l1 l2: seq A) :
- symmetric f ->
- perm_eq l1 l2 ->
- pairwise f l1 = pairwise f l2.
- Proof.
- move=> symf.
- move: l1 l2.
- wlog pimp: / forall l1 l2, perm_eq l1 l2 -> pairwise f l1 -> pairwise f l2.
- - apply.
- apply catCA_perm_ind=> l1 l2 l3.
- rewrite !pairwise_cat !allrel_catr.
- move/andP=>[/andP-[rel12 rel13] /andP-[p1 /andP-[rel23 /andP-[p2 p3]]]].
- repeat (try (apply/andP; split)) => //.
- by rewrite allrel_sym.
- - move=> l1 l2 pm.
- apply Bool.eq_bool_prop_intro.
- split =>/Bool.Is_true_eq_true=> HH
- ; apply Bool.Is_true_eq_left
- ; eapply pimp; try apply HH.
- + by [].
- + by rewrite perm_sym.
- Qed.
-
- Lemma pairwise_coprime_perm l l2:
- perm_eq l l2 ->
- pairwise coprime l = pairwise coprime l2.
- Proof.
- intros.
- apply pairwise_perm_sym => // x y.
- apply coprime_sym.
- Qed.
-
- Lemma prodn_filter1 [I : Type] (r : seq I) (F : I -> nat) :
- \prod_(i <- r | F i != 1) F i = \prod_(i <- r) F i.
- Proof.
- induction r.
- - by rewrite !big_nil.
- - rewrite !big_cons -IHr.
- case: eqVneq => /= [->|//].
- lia.
- Qed.
-
- Lemma balanced_chinese_list_mod_inner (l : seq (nat * nat)) :
- (forall p, p \in l -> 0 < p.2) ->
- pairwise coprime (map snd l) ->
- forall p,
- p \in l ->
- \prod_(q <- l | q != p) q.2 * ((p.1 * (egcdn (\prod_(q <- l | q != p) q.2) p.2).1) %% p.2) == p.1 %[mod p.2].
- Proof.
- intros.
- apply /eqP.
- rewrite modnMmr mulnC.
- apply egcd_coprime_mult.
- - rewrite big_seq_cond.
- apply prodn_cond_gt0 => i.
- move/andP => [iinl _].
- by apply H.
- - rewrite (perm_big_AC mulnA mulnC _ (r2:=p :: rem p l)).
- + case: (eqVneq p.2 1) => [->|neq].
- {
- by rewrite coprimen1.
- }
- rewrite -big_filter /= eqxx /=.
- rewrite (_ : (\prod_(i <- [seq i <- rem p l | i != p]) i.2) = (\prod_(i <- (map snd [seq x <- rem p l | x.2 != 1]) | (i != p.2)) i)).
- * rewrite coprime_sym -big_filter.
- apply all_coprime_prod.
- apply pairwise_coprime_cons.
- rewrite (pairwise_perm_sym coprime_sym (perm_map snd (perm_to_rem H1))) in H0.
- revert H0.
- apply subseq_pairwise.
- rewrite /= eqxx.
- rewrite filter_map.
- apply map_subseq.
- eapply subseq_trans.
- -- apply filter_subseq.
- -- apply filter_subseq.
- * rewrite big_map.
- transitivity (\prod_(i <- [seq i <- rem p l | i != p] | i.2 != 1) i.2)
- ; [by rewrite prodn_filter1 |].
- rewrite -big_filter.
- symmetry; rewrite -big_filter.
- rewrite -!filter_predI /predI.
- f_equal.
- apply eq_in_filter => x xin /=.
- case: (eqVneq x.2 1) => /=; [by rewrite Bool.andb_false_r |intros ne1].
- case: (eqVneq x p).
- { move => ->.
- by rewrite eqxx.
- }
- case: (eqVneq x.2 p.2) => //= eqq2.
- rewrite /eq_op/= eqq2 eqxx Bool.andb_true_r => eqq1.
- move: H0.
- rewrite (pairwise_perm_sym coprime_sym (perm_map snd (perm_to_rem H1))) /=.
- move/andP=> [].
- have x2in: x.2 \in [seq i.2 | i <- rem p l] by apply map_f.
- move/allP/(_ x.2 x2in).
- rewrite eqq2 /coprime gcdnn -eqq2 => eqq3.
- by rewrite eqq3/= in ne1.
- + by apply perm_to_rem.
- Qed.
-
- Lemma pairwise_coprime_uniq (l : seq nat) :
- (forall p, p \in l -> 1 < p) ->
- pairwise coprime l ->
- uniq l.
- Proof.
- intros.
- rewrite uniq_pairwise.
- have: (pairwise (fun x y => coprime x y && (1 x y xlt ylt xlty.
- apply H.
- by apply mem_nth.
- }
- apply sub_pairwise.
- intros ???.
- simpl.
- assert (x = y -> false).
- {
- intros.
- rewrite H2 in H1.
- move /andP in H1; destruct H1.
- rewrite /coprime gcdnn in H1.
- move /eqP in H1.
- lia.
- }
- by apply (contra_not_neq H2).
- Qed.
-
- Lemma pairwise_coprime_uniq_pair (l : seq (nat * nat)) :
- (forall p, p \in l -> 1 < p.2) ->
- pairwise coprime (map snd l) ->
- uniq l.
- Proof.
- intros.
- apply (map_uniq (f := snd)).
- apply pairwise_coprime_uniq; trivial.
- intros.
- move /mapP in H1.
- destruct H1.
- rewrite H2.
- by apply H.
- Qed.
-
- Lemma prod_split1 (l : seq (nat * nat)) (p : nat*nat) :
- uniq l ->
- p \in l ->
- \prod_(q<-l) q.2 = p.2 * \prod_(q <- l | q != p) q.2.
- Proof.
- intros.
- rewrite (big_rem_AC mulnA mulnC _ (z := p)) //.
- f_equal.
- rewrite -big_filter.
- symmetry; rewrite -big_filter.
- replace [seq _ <- rem p l | true] with [seq i <- l | i != p]; trivial.
- rewrite rem_filter // filter_predT.
- apply eq_filter.
- by rewrite /predC1 /=.
- Qed.
-
- Lemma balanced_chinese_list_mod_inner_lt (l : seq (nat * nat)) :
- uniq l ->
- (forall p, p \in l -> 0 < p.2) ->
- forall p,
- p \in l ->
- \prod_(q <- l | q != p) q.2 * ((p.1 * (egcdn (\prod_(q <- l | q != p) q.2) p.2).1) %% p.2) < \prod_(q <- l) q.2.
- Proof.
- intros.
- rewrite (prod_split1 H H1) mulnC ltn_pmul2r.
- - apply ltn_pmod.
- by apply H0.
- - rewrite big_seq_cond.
- apply prodn_cond_gt0 => i.
- intros.
- apply H0.
- move /andP in H2.
- tauto.
- Qed.
-
- Lemma sum_list_const {T} (l : list T) (c : nat) :
- \sum_(p <- l) c = (size l) * c.
- Proof.
- induction l.
- - by rewrite big_nil /= mul0n.
- - rewrite big_cons IHl /=.
- lia.
- Qed.
-
- Lemma big_sum_le_const {T:eqType} (l : list T) (F : T -> nat) (c : nat) :
- (forall p, p \in l -> F p <= c) ->
- \sum_(p <- l) F p <= (size l)*c.
- Proof.
- move=> inle.
- rewrite -sum_list_const.
- rewrite !big_seq.
- apply leq_sum => i ini.
- by apply inle.
- Qed.
-
- Lemma big_sum_lt_const {T:eqType} (l : list T) (F : T -> nat) (c : nat) :
- (forall p, p \in l -> F p < c) ->
- 0 < size l ->
- \sum_(p <- l) F p < (size l)*c.
- Proof.
- move=> inle lnnil.
- destruct l.
- - by rewrite !big_nil.
- - rewrite !big_cons /= mulSn.
- have lt1: F s < c.
- {
- apply inle.
- apply mem_head.
- }
- rewrite -(ltn_add2r (size l * c)) in lt1.
- eapply leq_ltn_trans; [|apply lt1].
- rewrite leq_add2l.
- apply big_sum_le_const => x xin.
- apply ltnW.
- apply inle.
- by rewrite in_cons xin orbT.
- Qed.
-
- Lemma balanced_chinese_list_mod_lt (l : seq (nat * nat)) :
- (forall p, p \in l -> 1 < p.2) ->
- pairwise coprime (map snd l) ->
- 0 < (size l) ->
- balanced_chinese_list l < (size l) * \prod_(q <- l) q.2.
- Proof.
- intros.
- apply big_sum_lt_const; trivial.
- intros.
- apply balanced_chinese_list_mod_inner_lt; trivial.
- - by apply pairwise_coprime_uniq_pair.
- - intros.
- apply H in H3.
- lia.
- Qed.
-
- Lemma modn_add0 (m a b : nat) :
- b == 0 %[mod m] ->
- a %% m + b == a %[mod m].
- Proof.
- intros.
- move /eqP in H.
- rewrite modnDml -modnDm H mod0n addn0 modn_mod //.
- Qed.
-
- Lemma modn_mull0 (m a b : nat) :
- a %% m = 0 ->
- (a * b) %% m = 0.
- Proof.
- intros.
- rewrite -(mod0n m) -modnMml H mul0n //.
- Qed.
-
- Lemma balanced_chinese_list_mod1 (l : seq (nat * nat)) :
- (forall p, p \in l -> 1 < p.2) ->
- pairwise coprime (map snd l) ->
- forall p,
- p \in l ->
- balanced_chinese_list l == p.1 %[mod p.2].
- Proof.
- intros.
- rewrite -modn_summ (bigD1_seq p) /= //.
- - have posl: (forall p, p \in l -> 0 < p.2).
- {
- intros ??.
- apply H in H2.
- lia.
- }
- rewrite (eqP (balanced_chinese_list_mod_inner posl H0 H1)).
- apply modn_add0.
- rewrite big1_idem //.
- intros.
- apply modn_mull0.
- rewrite -big_filter.
- rewrite (perm_big_AC mulnA mulnC _ (r2:=p :: rem p [seq i0 <- l | i0 != i])).
- + by rewrite big_cons modnMr.
- + apply perm_to_rem.
- by rewrite mem_filter H1 eq_sym H2.
- - by apply pairwise_coprime_uniq_pair.
- Qed.
-
- Lemma balanced_chinese_list_filter1 (l : seq (nat * nat)) :
- balanced_chinese_list [seq x <- l | x.2 != 1] = balanced_chinese_list l.
- Proof.
- rewrite /balanced_chinese_list.
-
- move: (perm_filterC (fun a => a.2 != 1) l).
- move/permPl => p1.
- rewrite -(perm_big_AC addnA addnC _ _ _ p1).
- rewrite big_cat /=.
-
- have ->: \sum_(i <-[seq x <- l | predC (fun a : nat * nat => a.2 != 1) x])
- \prod_(q <- l | q != i) q.2 * ((i.1 * (egcdn (\prod_(q <- l | q != i) q.2) i.2).1) %% i.2) = 0.
- {
- rewrite big_filter.
- under eq_bigr.
- {
- move=> i /=.
- case: eqVneq => //= -> _.
- rewrite modn1 muln0.
- over.
- }
- by rewrite big_const_seq iter_addn_0 mul0n.
- }
- rewrite addn0 !big_filter.
-
- apply eq_bigr => i ne1.
- f_equal.
- - rewrite -big_filter.
- symmetry.
- rewrite -big_filter -prodn_filter1 -big_filter.
- f_equal.
- rewrite -!filter_predI /predI.
- apply eq_in_filter => x xin /=.
- by rewrite andbC.
- - do 4 f_equal.
- rewrite -big_filter.
- symmetry.
- rewrite -big_filter -prodn_filter1 -big_filter.
- f_equal.
- rewrite -!filter_predI /predI.
- apply eq_in_filter => x xin /=.
- by rewrite andbC.
- Qed.
-
- Lemma balanced_chinese_list_mod (l : seq (nat * nat)) :
- (forall p, p \in l -> 0 < p.2) ->
- pairwise coprime (map snd l) ->
- forall p,
- p \in l ->
- balanced_chinese_list l == p.1 %[mod p.2].
- Proof.
- intros.
- case: (eqVneq p.2 1) => eqq.
- {
- by rewrite eqq !modn1.
- }
- have: balanced_chinese_list [seq x <- l | x.2 != 1] == p.1 %[mod p.2].
- - apply balanced_chinese_list_mod1.
- + move=> nn.
- rewrite mem_filter.
- move/andP => [ne1 nnin].
- move: (H _ nnin) ne1 => /=.
- destruct (nn.2) as [|[]]; lia.
- + move: H0.
- rewrite !pairwise_map.
- apply pairwise_filter.
- + by rewrite mem_filter H1 eqq.
- - by rewrite balanced_chinese_list_filter1.
- Qed.
-
- Lemma chinese_remainder_list_permutation (l l2: list (nat * nat)) :
- pairwise coprime (map snd l) ->
- perm_eq l l2 ->
- let m := \prod_(i <- map snd l) i in
- chinese_list l == chinese_list l2 %[mod m].
- Proof.
- intros co_l perm.
- apply chinese_remainder_list_unique; trivial.
- intros.
- assert (co_l2: pairwise coprime (map snd l2)).
- {
- rewrite (pairwise_coprime_perm (l2:=[seq i.2 | i <- l]))//.
- apply perm_map.
- by rewrite perm_sym.
- }
- move/mapP: H => [px ] in1 ->.
- rewrite (eqP (chinese_remainder_list co_l in1)).
- move: in1.
- rewrite (perm_mem perm)=> in2.
- by rewrite (eqP (chinese_remainder_list co_l2 in2)).
- Qed.
-
- Definition Zp_reduce_r (p q : nat) (a : 'Z_(p*q)) : 'Z_q := inZp a.
- Definition Zp_reduce_l (p q : nat) (a : 'Z_(p*q)) : 'Z_p := inZp a.
- Definition Zp_reduce_pair (p q : nat) (a : 'Z_(p*q)) := (Zp_reduce_l a,
- Zp_reduce_r a).
-
- Lemma modn_plus_const x a b m :
- a = b %[mod m] ->
- x + a = x + b %[mod m].
- Proof.
- intros.
- by rewrite -modnDm H modnDm.
- Qed.
-
- Lemma Zp_reduce_r_is_morphism (p q : nat) :
- 1 < p ->
- 1 < q ->
- rmorphism (@Zp_reduce_r p q).
- Proof.
- intros.
- assert (1 < p*q) by lia.
- assert ((Zp_trunc q).+2 %| (Zp_trunc (p * q)).+2).
- {
- rewrite !Zp_cast //.
- apply dvdn_mull.
- apply dvdnn.
- }
- constructor.
- - intros ??.
- apply val_inj; simpl.
- rewrite modnDm modn_dvdm //.
- apply (@modn_plus_const x (((Zp_trunc (p * q)).+2 - y) %% (Zp_trunc (p * q)).+2)
- ((Zp_trunc q).+2 - y %% (Zp_trunc q).+2) (Zp_trunc q).+2).
- rewrite modn_dvdm //.
- destruct y.
- simpl.
- rewrite !Zp_cast //.
- rewrite Zp_cast // in i.
- clear H2 x.
- rewrite modnB; try lia.
- rewrite modnMl.
- case (boolP (0 < m %% q)); intros; simpl.
- + rewrite mul1n addn0.
- assert (q - m%%q < q) by lia.
- rewrite (modn_small H2) //.
- + rewrite mul0n addn0.
- assert (m%%q = 0) by lia.
- rewrite H2 !subn0 modnn //.
- - constructor.
- + intros ??.
- apply val_inj; simpl.
- rewrite modnMm modn_dvdm // !Zp_cast //.
- + apply val_inj; simpl.
- rewrite modn_dvdm // !Zp_cast //.
- Qed.
-
- Lemma Zp_reduce_l_is_morphism (p q : nat) :
- 1 < p ->
- 1 < q ->
- rmorphism (@Zp_reduce_l p q).
- Proof.
- intros.
- rewrite /Zp_reduce_l mulnC.
- by apply Zp_reduce_r_is_morphism.
- Qed.
-
- Lemma Zp_reduce_pair_is_morphism (p q : nat) :
- 1 < p ->
- 1 < q ->
- rmorphism (@Zp_reduce_pair p q).
- Proof.
- intros.
- destruct (@Zp_reduce_l_is_morphism p q H H0) as [? [? ?]].
- destruct (@Zp_reduce_r_is_morphism p q H H0) as [? [? ?]].
- constructor.
- - intros ??.
- rewrite /Zp_reduce_pair base base0 //.
- - constructor.
- + intros ??.
- rewrite /Zp_reduce_pair m m0 //.
- + rewrite /Zp_reduce_pair e e0 //.
- Qed.
-
- Definition Zp_lift_pair (p q : nat) (r : 'Z_p * 'Z_q) : 'Z_(p*q) :=
- inZp (chinese p q r.1 r.2).
-
- Lemma modn_muln_l x p q :
- (x %% (p * q)) %% p = x %% p.
- Proof.
- symmetry.
- apply /eqP.
- have HH: (x %% (p * q) <= x) by apply leq_mod.
- rewrite eqn_mod_dvd //.
- apply mul_dvdn_l with (d2 := q).
- by rewrite -(eqn_mod_dvd (p * q)) // modn_mod.
- Qed.
-
- Lemma modn_muln_r x p q :
- (x %% (p * q)) %% q = x %% q.
- Proof.
- by rewrite mulnC modn_muln_l.
- Qed.
-
- Lemma Zp_lift_pair_is_morphism (p q : nat) :
- 1 < p ->
- 1 < q ->
- coprime p q ->
- rmorphism (@Zp_lift_pair p q).
- Proof.
- intros.
- assert (1 < p*q) by lia.
- generalize (chinese_remainder H1); intros.
- generalize (chinese_modl H1); intros.
- generalize (chinese_modr H1); intros.
- constructor.
- - intros ??.
- unfold Zp_lift_pair.
- rewrite -inZp_add.
- apply val_inj.
- destruct x, y.
- destruct s, s0, s1, s2.
- rewrite /= !Zp_cast //.
- apply /eqP.
- rewrite H3.
- apply /andP.
- split; apply /eqP.
- + rewrite H4.
- symmetry.
- rewrite -modnDm H4 modnB; [|lia|lia].
- rewrite modn_muln_l modnMr addn0.
- case (boolP (0 < (chinese p q m1 m2 %% p))); simpl; intros.
- * rewrite mul1n H4.
- symmetry.
- rewrite -modnDm !modn_mod.
- apply modn_plus_const.
- rewrite Zp_cast in i1; [|lia].
- rewrite modnB; [|lia|lia].
- rewrite modnn (modn_small i1) addn0.
- case (boolP (0 < m1)); simpl; intros.
- -- by rewrite mul1n.
- -- assert (m1 = 0) by lia.
- by rewrite H6 mul0n !subn0 modnn mod0n.
- * assert (chinese p q m1 m2 %% p = 0) by lia.
- rewrite mul0n H4.
- symmetry.
- rewrite -modnDm !modn_mod.
- apply modn_plus_const.
- rewrite Zp_cast in i1; [|lia].
- rewrite modnB; [|lia|lia].
- rewrite modnn (modn_small i1) addn0.
- case (boolP (0 < m1)); simpl; intros.
- -- rewrite mul1n.
- rewrite H4 (modn_small i1) in H6.
- lia.
- -- by rewrite mul0n.
- + rewrite H5.
- symmetry.
- rewrite -modnDm H5 modnB; [|lia|lia].
- rewrite modn_muln_r modnMl addn0.
- case (boolP (0 < (chinese p q m1 m2 %% q))); simpl; intros.
- * rewrite mul1n H5.
- symmetry.
- rewrite -modnDm !modn_mod.
- apply modn_plus_const.
- rewrite Zp_cast in i2; [|lia].
- rewrite modnB; [|lia|lia].
- rewrite modnn (modn_small i2) addn0.
- case (boolP (0 < m2)); simpl; intros.
- -- by rewrite mul1n.
- -- assert (m2 = 0) by lia.
- by rewrite H6 mul0n !subn0 modnn mod0n.
- * assert (chinese p q m1 m2 %% q = 0) by lia.
- rewrite mul0n H5.
- symmetry.
- rewrite -modnDm !modn_mod.
- apply modn_plus_const.
- rewrite Zp_cast in i2; [|lia].
- rewrite modnB; [|lia|lia].
- rewrite modnn (modn_small i2) addn0.
- case (boolP (0 < m2)); simpl; intros.
- -- rewrite mul1n.
- rewrite H5 (modn_small i2) in H6.
- lia.
- -- by rewrite mul0n.
- - constructor.
- + intros ??.
- unfold Zp_lift_pair.
- rewrite -inZp_mul.
- apply val_inj.
- destruct x, y.
- destruct s, s0, s1, s2.
- rewrite /= !Zp_cast //.
- apply /eqP.
- rewrite H3.
- apply /andP.
- split; apply /eqP; symmetry.
- * by rewrite -modnMm !H4 /= modnMm modn_mod.
- * by rewrite -modnMm !H5 /= modnMm modn_mod.
- + unfold Zp_lift_pair.
- apply val_inj.
- rewrite /= !Zp_cast //.
- apply /eqP.
- rewrite H3.
- apply /andP.
- split; apply /eqP.
- * rewrite H4 !modn_small //.
- * rewrite H5 !modn_small //.
- Qed.
-
- Lemma right_inv_chinese (p q : nat) :
- 1 < p ->
- 1 < q ->
- coprime p q ->
- cancel (@Zp_lift_pair p q) (@Zp_reduce_pair p q).
- Proof.
- intros.
- unfold cancel.
- intros.
- unfold Zp_reduce_pair, Zp_lift_pair.
- destruct x.
- simpl.
- unfold Zp_reduce_l, Zp_reduce_r.
- generalize (chinese_remainder H1); intros.
- generalize (chinese_modl H1); intros.
- generalize (chinese_modr H1); intros.
- assert (1 < p * q) by lia.
- destruct o, o0.
- f_equal.
- - apply val_inj.
- rewrite /= !Zp_cast //.
- rewrite Zp_cast // in i.
- rewrite modn_muln_l H3.
- by rewrite (modn_small i).
- - apply val_inj.
- rewrite /= !Zp_cast //.
- rewrite Zp_cast // in i0.
- rewrite modn_muln_r H4.
- by rewrite (modn_small i0).
- Qed.
-
- Lemma left_inv_chinese (p q : nat) :
- 1 < p ->
- 1 < q ->
- coprime p q ->
- cancel (@Zp_reduce_pair p q) (@Zp_lift_pair p q).
- Proof.
- intros.
- unfold cancel.
- intros.
- unfold Zp_reduce_pair, Zp_lift_pair.
- destruct x.
- unfold Zp_reduce_l, Zp_reduce_r.
- generalize (chinese_remainder H1); intros.
- generalize (chinese_modl H1); intros.
- generalize (chinese_modr H1); intros.
- assert (1 < p * q) by lia.
- apply val_inj.
- simpl.
- rewrite !Zp_cast //.
- rewrite Zp_cast // in i.
- replace m with (m %% (p * q)) at 3.
- - apply /eqP.
- rewrite H2.
- apply /andP.
- split.
- + apply /eqP.
- by rewrite H3 modn_mod.
- + apply /eqP.
- by rewrite H4 modn_mod.
- - by rewrite (modn_small i).
- Qed.
-
- Lemma bijective_reduce_pair (p q : nat)
- (pbig: 1 < p)
- (qbig: 1 < q)
- (cop: coprime p q) :
- bijective (@Zp_reduce_pair p q).
- Proof.
- eapply Bijective.
- - by apply left_inv_chinese.
- - by apply right_inv_chinese.
- Qed.
-
- Lemma bijective_lift_pair (p q : nat)
- (pbig: 1 < p)
- (qbig: 1 < q)
- (cop: coprime p q) :
- bijective (@Zp_lift_pair p q).
- Proof.
- eapply Bijective.
- - by apply right_inv_chinese.
- - by apply left_inv_chinese.
- Qed.
-
-End chinese.
-
-(* order of 3 mod 2^(n+2) = 2^n *)
-(* show 3^(2^n) <> 1 mod 2^(n+3) *)
-
-From mathcomp Require Import ssrnat.
-
-Lemma n_n1_even j :
- exists k,
- j * (j + 1) = k.*2.
-Proof.
- assert (~~ odd(j * (j + 1))).
- {
- replace (j+1) with (S j) by lia.
- rewrite oddM oddS.
- by case: (odd j).
- }
- apply even_halfK in H.
- exists ((j * (j + 1)) ./2).
- by rewrite H.
-Qed.
-
-Lemma modn_sub i j m :
- i >= j ->
- (i == j %[mod m]) = (i - j == 0 %[mod m]).
-Proof.
- move/eqn_mod_dvd->.
- by rewrite mod0n.
-Qed.
-
-Lemma modn_sub_iff i j m :
- i >= j ->
- i = j %[mod m] <-> i - j = 0 %[mod m].
-Proof.
- move/modn_sub=>eqq.
- split; move/eqP
- ; [rewrite eqq | rewrite -eqq]
- ; by move/eqP.
-Qed.
-
-
-Lemma subn_sqr_1 (x : nat) :
- x^2-1 = (x + 1) * (x - 1).
-Proof.
- replace (x^2-1) with (x^2-1^2) by lia.
- by rewrite mulnC -subn_sqr.
-Qed.
-
-Lemma ord_odd_pow_2 j n :
- (2*j+1)^(2^n.+1) = 1 %[mod 2^(n.+3)].
-Proof.
- induction n.
- - rewrite expn1 expnS expn1.
- replace ((2 * j + 1) * (2 * j + 1)) with (4*(j*(j+1)) + 1) by lia.
- destruct (n_n1_even j).
- rewrite H /=.
- replace (2^3) with 8 by lia.
- replace (4 * (x.*2)) with (8 * x) by lia.
- by rewrite -modnDm modnMr modnDmr.
- - rewrite expnS (mulnC _ (2^n.+1)) expnM (expnS _ n.+3).
- rewrite modn_sub_iff; [|lia].
- rewrite subn_sqr_1.
- rewrite modn_sub_iff in IHn; [|lia].
- assert (exists k,
- 2 * k = ((2 * j + 1) ^ 2 ^ n.+1 + 1)).
- {
- assert (~~ odd ((2 * j + 1) ^ 2 ^ n.+1 + 1)).
- {
- rewrite oddD oddX oddD.
- replace (2 *j) with (j.*2) by lia.
- rewrite odd_double /=.
- lia.
- }
- apply even_halfK in H.
- exists (((2 * j + 1) ^ 2 ^ n.+1 + 1)./2).
- rewrite -H.
- lia.
- }
- destruct H.
- rewrite -H -mulnA -muln_modr.
- replace 0 with (2*0) at 7 by lia.
- rewrite -muln_modr.
- f_equal.
- by rewrite -modnMm IHn modnMm muln0.
- Qed.
-
-Lemma ord_odd_pow_2' j n :
- odd j ->
- j^(2^n.+1) = 1 %[mod 2^(n.+3)].
-Proof.
- intros.
- generalize (ord_odd_pow_2 (j./2) n); intros.
- generalize (odd_double_half j); intros.
- replace (2 * j./2) with (j./2.*2) in H0 by lia.
- rewrite H /= addnC in H1.
- by rewrite -H1.
-Qed.
-
-Lemma iotaSn0 m n : n != 0 ->
- iota m n = m :: iota m.+1 n.-1.
-Proof.
- case: n => //=.
-Qed.
-
-Lemma index_iotaSn0 m n : m < n ->
- index_iota m n = m :: index_iota m.+1 n.
-Proof.
- rewrite /index_iota=> mltn.
- rewrite iotaSn0; try lia.
- do 2 f_equal.
- lia.
-Qed.
-
-(*
-Lemma add4_pow2_mod j n :
- (j + 4)^(2 ^n) = j^(2^n) + (2^n.+2)*j^(2^n-1) %[mod 2^n.+3].
-Proof.
- rewrite (Pascal j 4 (2^n)) /=.
- move: (@big_mkord _ 0 addn (2 ^ n).+1 predT (fun i => 'C(2 ^ n, i) * (j ^ (2 ^ n - i) * 4 ^ i)))=> /= <-.
- rewrite index_iotaSn0 // big_cons.
- rewrite index_iotaSn0 ?big_cons; [| lia].
- rewrite expn0 expn1 muln1 subn0 bin0 bin1 mul1n addnA.
- rewrite (mulnC _ 4) mulnA.
- replace (2^n*4) with (2^n.+2) by (rewrite !expnS; lia).
- assert (\sum_(2 <= j0 < (2 ^ n).+1) 'C(2 ^ n, j0) * (j ^ (2 ^ n - j0) * 4 ^ j0) = 0 %[mod 2^n.+3]).
- {
- rewrite -modn_summ.
- rewrite (eqP (_ : \sum_( _ <= _ < _ ) _ == 0)) //.
- rewrite (big_nat_widenl _ 0) //.
- move: (@big_mkord _ 0 addn (2 ^ n).+1 (fun i => (andb true (leq (S (S O)) i))) (fun i => ('C(2 ^ n, i) * (j ^ (2 ^ n - i) * 4 ^ i))
- %% 2 ^ n.+3)) => /= ->.
- rewrite sum_nat_eq0.
- apply/forallP => x.
- apply/implyP => xbig.
- assert (forall k, k < n -> ('C(2 ^ n, 2^k.+1) * 4 ^ 2^k.+1) %% 2 ^ n.+3 == 0)%N.
- {
- intros.
- assert (exists q, 'C(2^n, 2^k.+1) = q * 2^(n-k.+1)).
- {
- admi t.
- }
- destruct H0.
- rewrite H0.
- replace 4 with (2^2) by lia.
- rewrite -expnM -expnS.
- replace (x0 * 2 ^ (n - k.+1) * 2 ^ 2 ^ k.+2) with
- (x0 * 2^ (n - k.+1 + (2^k.+2))).
- - assert (2^k.+2 >= k.+4).
- {
- clear H H0.
- induction k; trivial.
- assert (k.+4 < 2*2^k.+2) by lia.
- by rewrite expnS.
- }
- rewrite -modnMm.
- assert (exists q, (2 ^ (n - k.+1 + 2 ^ k.+2) = q * 2^n.+3)).
- {
- admi t.
- }
- destruct H2.
- by rewrite H2 modnMl muln0 mod0n.
- - by rewrite -mulnA expnD.
- }
- assert (('C(2 ^ n, x) * 4 ^ x) %% 2 ^ n.+3 == 0)%N.
- {
- (* 2^(n-1) | 'C(2^n, 2), ok for x = 2 *)
- (* 2^n | 'C(2^n,3), ok for x = 3 *)
- (* 2^(n-2) | 'C(2^n, 4), ok for x = 4 *)
- (* enough to prove for x = 2^k, since sucessive values are better *)
- admi t.
- }
- by rewrite (mulnC _ (4 ^ x)) mulnA -modnMm (eqP H0) mul0n mod0n.
- }
- by rewrite -modnDmr -modnDmr H !mod0n addn0.
-
-Admi tted.
-
-Lemma ord_pow_2_odd j n :
- odd j ->
- j ^ (2^n) = 1 %[mod 2^n.+3] ->
- (j + 4)^(2^n) <> 1 %[mod 2^n.+3].
-Proof.
- intros.
- rewrite add4_pow2_mod -modnDm H0 modnDm addnC modn_sub_iff; [|lia].
- replace (2 ^ n.+2 * j ^ (2 ^ n - 1) + 1 - 1) with
- (2 ^ n.+2 * j ^ (2 ^ n - 1)) by lia.
- rewrite (expnS _ n.+2) (mulnC 2 _).
- replace 0 with (2^n.+2 * 0) at 6 by lia.
- rewrite -!muln_modr mod0n muln0.
- apply /eqP.
- rewrite muln_eq0.
- apply /norP.
- split; [lia |].
- by rewrite modn2 oddX H orbT.
-Qed.
-*)
-
-(* https://math.stackexchange.com/questions/459815/the-structure-of-the-group-mathbbz-2n-mathbbz *)
-
-
-Lemma mod_mul_mul_0 a b m1 m2 :
- a == 0 %[mod m1] && (b == 0 %[mod m2]) ->
- a * b == 0 %[mod m1 * m2].
-Proof.
- do 3 (rewrite eqn_mod_dvd; [|lia]).
- rewrite !subn0.
- move /andP=>[diva divb].
- by rewrite dvdn_mul.
-Qed.
-
-Lemma mod_mul_mul_0_alt a b m1 m2 :
- a = 0 %[mod m1] /\ (b = 0 %[mod m2]) ->
- a * b = 0 %[mod m1 * m2].
-Proof.
- move=>[/eqP? /eqP?].
- by apply/eqP/mod_mul_mul_0/andP.
-Qed.
-
-Lemma mod_pow2_sqr_aux a b n :
- b <= a ->
- a = b %[mod 2^n.+1] ->
- a^2 = b^2 %[mod 2^n.+2].
-Proof.
- intros.
- rewrite modn_sub_iff // in H0.
- rewrite modn_sub_iff.
- - rewrite subn_sqr.
- rewrite expnS mulnC.
- apply mod_mul_mul_0_alt.
- split; trivial.
- rewrite -modn_sub_iff // expnS in H0.
- lia.
- - by rewrite leq_sqr.
- Qed.
-
-Lemma mod_pow2_sqr a b n :
- a = b %[mod 2^n.+1] ->
- a^2 = b^2 %[mod 2^n.+2].
-Proof.
- case (boolP (b <= a)); intros.
- - by apply mod_pow2_sqr_aux.
- - symmetry.
- symmetry in H.
- apply mod_pow2_sqr_aux; lia.
- Qed.
-
-Lemma ord_5_pow_2 n :
- 5 ^ (2 ^ n) = 1 + 2^n.+2 %[mod 2^n.+3].
-Proof.
- induction n.
- - lia.
- - rewrite expnS mulnC expnM.
- apply mod_pow2_sqr in IHn.
- rewrite IHn.
- generalize (expnD (1 + 2^n.+2) 1 1); intros.
- rewrite H !expn1.
- replace ((1 + 2 ^ n.+2) * (1 + 2 ^ n.+2)) with
- (1 + 2 * (2^n.+2) + (2^n.+2)*(2^n.+2)) by lia.
- rewrite -expnD.
- assert (2^(n.+2 + n.+2) = 0 %[mod 2^n.+4]).
- {
- replace (n.+2 + n.+2) with ((2 * n).+4) by lia.
- rewrite !expnS -!muln_modr.
- replace (2^(2*n) %% 2^n) with 0.
- - rewrite muln0 mod0n //.
- - rewrite mulnC expnM.
- generalize (expnD (2 ^ n) 1 1); intros.
- by rewrite H0 !expn1 -modnMm modnn muln0 mod0n.
- }
- by rewrite -modnDm H0 mod0n addn0 modn_mod (expnS _ (n.+2)).
- Qed.
-
-Lemma add_exp_mod_p a b p :
- prime p ->
- (a + b)^p = a^p + b^p %[mod p].
-Proof.
- move=> pprime.
- rewrite expnDn.
- move: (prime_gt0 pprime).
- case: p pprime; [lia |]=> p pprime _.
- rewrite big_ord_recr big_ord_recl /=.
- rewrite bin0 binn subn0 subnn !expn0 !mul1n muln1.
- rewrite addnC addnA.
- rewrite -modnDmr -modn_summ.
- suff/eqP->: \sum_(i < p) (('C(p.+1, bump 0 i) * (a ^ (p.+1 - bump 0 i) * b ^ bump 0 i)) %% p.+1) == 0
- by rewrite mod0n addn0 addnC.
- rewrite sum_nat_eq0.
- apply/forallP=> k.
- apply/implyP=> _.
- rewrite -modnMml (eqP (_ : 'C(p.+1, bump 0 k) == 0 %[mod p.+1])).
- - by rewrite mod0n.
- - rewrite (eqn_mod_dvd p.+1) // subn0 prime_dvd_bin //.
- rewrite /bump.
- destruct k; simpl; lia.
-Qed.
-
-Lemma iffbP {x y : bool}: reflect (x <-> y) (x == y).
-Proof.
- case: x y; case
- ; rewrite eqE /= /eqb /=
- ; constructor
- ; firstorder.
-Qed.
-
-Lemma iffEq {x y : bool}: (x <-> y) <-> (x = y).
-Proof.
- case: x y; case; firstorder.
- symmetry; firstorder.
-Qed.
-
-Lemma prime_pow_dvd_aux k p n:
- k <= p^n ->
- forall j,
- j <= n ->
- (p ^ j %| k) = (p^j %| (p^n - k)).
-Proof.
- intros kle j jlt.
- assert (p^j %| p^n).
- {
- by apply dvdn_exp2l.
- }
- apply iffEq.
- split; intros.
- - rewrite dvdn_sub //.
- - by rewrite -(dvdn_subr (m := p^n)).
-Qed.
-
-Lemma prime_pow_dvd j k p n :
- prime p ->
- ~~ (p %| j) ->
- p^n.+1 %| j * k ->
- p^n.+1 %| k.
-Proof.
- move=> pprime pndivj.
- induction n.
- - rewrite !expn1 Euclid_dvdM //.
- case/orP => // eqq1.
- by rewrite eqq1 in pndivj.
- - intros.
- assert (p^n.+1 %| j * k).
- {
- rewrite expnS in H.
- by apply mul_dvdn_r in H.
- }
- generalize (divnK (IHn H0)); intros.
- clear IHn H0.
- rewrite -H1 expnS dvdn_pmul2r.
- + rewrite -H1 mulnA expnS dvdn_pmul2r in H.
- * rewrite Euclid_dvdM // in H.
- move /orP in H.
- destruct H; trivial.
- by move /negP in pndivj.
- * apply prime_gt1 in pprime; lia.
- + apply prime_gt1 in pprime; lia.
- Qed.
-
-
-Lemma expn_boundr a b c :
- 1 < a ->
- a ^ b <= c ->
- b <= c.
-Proof.
- move=> abig able.
- rewrite -(@leq_exp2l a) // (leq_trans able) //.
- by rewrite ltnW // ltn_expl.
-Qed.
-
-Lemma max_prime_pow_dvd j p :
- 1 < p ->
- 0 < j ->
- {k | (p^k %| j) && ~~ (p^k.+1 %| j)}.
-Proof.
- move=> pbig jnneg.
- have exP: exists i : nat, [pred k | p ^ k %| j] i.
- {
- exists 0.
- by rewrite /= expn0 dvd1n.
- }
-
- have bounded: forall i : nat, [pred k | p ^ k %| j] i -> i <= j.
- {
- move=> i /=.
- move/(dvdn_leq jnneg).
- by apply expn_boundr.
- }
-
- exists (ex_maxn exP bounded).
- case: ex_maxnP=> i /= divi ub.
- rewrite divi /=.
- apply/negP.
- move/ub.
- by rewrite ltnn.
-Qed.
-
-(*
-Lemma prime_dvd_pow_m1_bin k p n : prime p -> k < p^n.+1 ->
- ~~ (p %| 'C(p^n.+1-1, k)).
-Proof.
- intros.
- apply /negP.
- revert H0.
- induction k.
- - rewrite bin0.
- intros ??.
- apply prime_gt1 in H.
- apply dvdn_leq in H1; lia.
- - intros.
- assert (k < p^n.+1) by lia.
- specialize (IHk H1).
- generalize (mul_bin_left (p^k-1) k); intros.
- intros ?.
- generalize (prime_gt1 H); intros.
- replace (p^k-1-k) with (p^k-k.+1) in H2.
- + assert (0 < k.+1) by lia.
- destruct (max_prime_pow_dvd H4 H5).
- assert (0< p^k -k.+1).
- {
- admi t.
- }
- destruct (max_prime_pow_dvd H4 H6).
- assert (x = x0).
- {
- admi t.
- }
- move /andP in i.
- move /andP in i0.
- destruct i; destruct i0.
- admi t.
- + clear IHk H2 H3; lia.
-Admi tted.
-*)
-
-Lemma prime_power_dvd_mul_helper p k a b :
- prime p ->
- p ^ k %| a * b ->
- exists j, (j<=k) && (p^j %| a) && (p^(k-j) %| b).
-Proof.
- intros.
- induction k.
- - exists 0.
- by rewrite subn0 !expn0 !dvd1n.
- - assert (p^k %| a*b).
- {
- rewrite expnS in H0.
- by apply mul_dvdn_r in H0.
- }
- specialize (IHk H1).
- destruct IHk.
- move /andP in H2.
- destruct H2.
- move /andP in H2.
- destruct H2.
- generalize (divnK H3); intros.
- generalize (divnK H4); intros.
- rewrite -H5 -H6 expnS in H0.
- assert (forall k, 0 < p^k).
- {
- intros.
- rewrite expn_gt0.
- apply prime_gt0 in H.
- apply /orP.
- tauto.
- }
- replace (a %/ p ^ x * p ^ x * (b %/ p ^ (k - x) * p ^ (k - x))) with
- (a %/ p ^ x * (b %/ p ^ (k - x) * p ^ k)) in H0.
- + rewrite mulnA in H0.
- rewrite dvdn_pmul2r // in H0.
- rewrite (Euclid_dvdM _ _ H) in H0.
- move /orP in H0.
- destruct H0.
- * exists (x.+1).
- apply /andP; split.
- -- apply /andP; split; trivial.
- rewrite -H6 expnS dvdn_pmul2r //.
- -- replace (k.+1-x.+1) with (k-x); trivial.
- * exists x.
- apply /andP; split; trivial.
- -- apply /andP; split; trivial.
- lia.
- -- rewrite -H5.
- replace (k.+1-x) with (1 + (k - x)).
- ++ rewrite expnD expn1 dvdn_pmul2r //.
- ++ clear H1 H3 H4 H5 H0.
- lia.
- + clear H1 H3 H4 H5 H6 H0.
- replace (p^k) with (p^x * p^(k-x)); try lia.
- rewrite -expnD.
- f_equal.
- lia.
- Qed.
-
-Lemma prime_pow_dvd_gen' j k p e n :
- prime p ->
- e <= n ->
- (p^e %| j) ->
- ~~ (p^e.+1 %| j) ->
- p^n.+1 %| j * k ->
- p^(n.+1-e) %| k.
-Proof.
- intros p_prim ele divj notdivj divjk.
- generalize (divnK divj); intros.
- rewrite -H in divjk.
- rewrite mulnC mulnA in divjk.
- replace (p^n.+1) with (p^(n.+1-e) * p^e) in divjk.
- - rewrite dvdn_pmul2r in divjk.
- assert (~~ (p %| (j %/ p ^ e))).
- {
- apply /negP.
- intros ?.
- apply prime_gt0 in p_prim.
- assert (0 < p^e).
- {
- rewrite expn_gt0.
- apply /orP.
- tauto.
- }
- rewrite -(dvdn_pmul2r H1) in H0.
- rewrite divnK // in H0.
- rewrite expnS in notdivj.
- move /negP in notdivj.
- tauto.
- }
- + replace (n.+1-e) with ((n-e).+1) in divjk.
- * rewrite mulnC in divjk.
- generalize (prime_pow_dvd p_prim H0 divjk); intros.
- replace (n.+1-e) with ((n-e).+1); trivial.
- clear divj notdivj H H0 divjk H1; lia.
- * clear divj notdivj H H0 divjk; lia.
- + rewrite expn_gt0.
- apply prime_gt0 in p_prim.
- apply /orP.
- tauto.
- - rewrite -expnD.
- f_equal.
- clear divj notdivj H divjk; lia.
- Qed.
-
-Lemma prime_pow_dvd_gen j k p e n :
- prime p ->
- e <= n.+1 ->
- (p^e %| j) ->
- ~~ (p^e.+1 %| j) ->
- p^n.+1 %| j * k ->
- p^(n.+1-e) %| k.
-Proof.
- intros p_prim ele divj notdivj divjk.
- case (boolP (e <= n)).
- - intros.
- by apply prime_pow_dvd_gen' with (j := j).
- - intros.
- assert (e = n.+1) by lia.
- by rewrite H subnn expn0 dvd1n.
- Qed.
-
-(*
-Lemma prime_power_dvd_mul p k a b :
- prime p ->
- p ^ k %| a * b = [exists j:(ordinal (k.-1.+1)), (p^j %| a) && (p^(k-j) %| b)].
-Proof.
-(* move=> pprime.
- case: (eqVneq a 0)=> [-> | neqq].
- - admi t.
- - case: (@max_prime_pow_dvd a p).
- + by apply prime_gt1.
- + lia.
- + move => x /andP-[pdiva pndiva].
- have: p^(k-x) %| b = (p ^ k %| a * b).
- {
- repeat case: dvdnP => //.
- - move=> HH1 [j eqq1]; subst.
- elim HH1.
- move: pdiva.
- move/dvdnP=> [jj eqq2]; subst.
- exists (jj * j).
- rewrite -!mulnA.
- f_equal.
- rewrite mulnC -!mulnA.
- f_equal.
- rewrite -expnD.
- f_equal.
- rewrite subnK //.
-
- -
- }
-
- case: dvdnP.
- apply/iffEq.
- split=> HH.
- * have pdvib: p^(x-k) %| b.
- {
- Search (_ * _ %| _ * _).
-
- move/andP.
-
- case: dvdnP=> HH; symmetry.
- - destruct HH as [j eqq].
-
-
-
-
- induction k => /=.
- - rewrite expn0 dvd1n.
- symmetry.
- apply/existsP.
- simpl.
- exists ord0.
- by rewrite /= expn0 !dvd1n.
- - rewrite expnS.
- apply/iffEq.
- split=> HH.
- + have: p ^ k %| a * b by apply mul_dvdn_r in HH.
- rewrite IHk.
- move/existsP=>[j ]/andP-[diva divb].
- apply/existsP.
- case_eq (dvdn (p^j.+1) a) => eqq1.
- * have ordpf: j.+1 < k.+1.
- {
- admi t.
- }
- exists (Ordinal ordpf).
- by rewrite /= eqq1 /= subSS.
- * have ordpf: j < k.+1.
- {
- admi t.
- }
- exists (Ordinal ordpf).
- rewrite /= diva /=.
-
-
- dvdn_pmul2r: forall [p d m : nat], 0 < p -> (d * p %| m * p) = (d %| m)
- dvdn_pmul2l: forall [p d m : nat], 0 < p -> (p * d %| p * m) = (d %| m)
- *)
-
-Admi tted.
-*)
-
-Lemma prime_pow_dvd_bin_full j k p n :
- prime p -> 0 < k < p^n -> 0 < n ->
- p^j %| k ->
- ~~ (p^j.+1 %| k) ->
- p^(n-j) %| 'C(p^n, k).
-Proof.
- intros.
- have HH: (p^n %| k * 'C(p^n,k)).
- {
- destruct k; trivial.
- by rewrite -mul_bin_diag dvdn_mulr.
- }
- assert (j < n).
- {
- assert (0 < k) by lia.
- generalize (dvdn_leq H4 H2); intros.
- clear H2 H3 HH.
- apply ltn_pexp2l with (m := p); try lia.
- by apply prime_gt0.
- }
- destruct n; trivial.
- apply prime_pow_dvd_gen with (j := k); trivial.
- lia.
-Qed.
-
-Lemma prime_dvd_pow_bin k p n :
- prime p ->
- 0 < k < p ->
- p^n.+1 %| 'C(p^n.+1, k).
-Proof.
- intros.
- generalize (mul_bin_down (p^n.+1) k); intros.
- have HH: (~~ (p %| p^n.+1-k)).
- {
- rewrite dvdn_subr.
- - apply/negP.
- move/dvdn_leq.
- lia.
- - apply (@leq_trans (k ^ n.+1)).
- + rewrite -[k]expn1 leq_pexp2l //.
- lia.
- + rewrite leq_exp2r //.
- lia.
- - by rewrite dvdn_exp.
- }
- assert (p^n.+1 %| (p^n.+1-k) * 'C(p^n.+1,k)).
- {
- rewrite -H1.
- apply dvdn_mulr.
- apply dvdnn.
- }
- rewrite Gauss_dvdr // in H2.
- rewrite coprimeXl //.
- by rewrite prime_coprime.
-Qed.
-
-(*
-Lemma add_exp_mod_exp_p p k :
- prime p ->
- odd p ->
- (1 + p)^(p^k) = 1 %[mod p^k.+1].
-Proof.
- intros.
- rewrite expnDn.
- rewrite big_ord_recl /= bin0 exp1n !mul1n expn0.
- rewrite -modnDmr -modn_summ.
- suff/eqP-> : \sum_(i < p ^ k) 'C(p ^ k, bump 0 i) * (1 ^ (p ^ k - bump 0 i) * p ^ bump 0 i) %% (p ^ k.+1) == 0
- by rewrite mod0n addn0.
-
- rewrite sum_nat_eq0.
- apply/forallP=> i.
- apply/implyP=> _.
- rewrite exp1n mul1n.
- rewrite /bump /=.
-
-
-
- (*
- rewrite bin_ffactd.
- *)
-
-Admi tted.
-
-
-
-Lemma ord_p1_pow_p p n :
- prime p ->
- odd p ->
- (1 + p)^(p^n) = 1 + p^n.+1 %[mod p^n.+2].
-Proof.
- intros.
- induction n.
- - by rewrite expn0 !expn1.
- - rewrite expnS expnS.
- apply (f_equal (fun z => z^p)) in IHn.
- Admi tted.
-
- *)
-
-Lemma ord_5_pow_2_neq n :
- 5^(2^n) <> 1 %[mod 2^n.+3].
-Proof.
- rewrite ord_5_pow_2 !expnS !modn_small; lia.
-Qed.
-
-Lemma ord_5_pow_2_neq_m1 n :
- 5^(2^n) <> 2^n.+3-1 %[mod 2^n.+3].
-Proof.
- rewrite ord_5_pow_2 !expnS; lia.
-Qed.
-
-Lemma ord_pow_gcd b e1 e2 n :
- b^e1 = 1 %[mod n] ->
- b^e2 = 1 %[mod n] ->
- b^(gcdn e1 e2) = 1 %[mod n].
-Proof.
- intros.
- destruct e2.
- - by rewrite gcdn0.
- - assert (0 < e2.+1) by lia.
- destruct (egcdnP e1 H1).
- apply (f_equal (fun z => z^kn %% n)) in H.
- rewrite !modnXm -expnM mulnC exp1n in H.
- apply (f_equal (fun z => z^km %% n)) in H0.
- by rewrite !modnXm -expnM mulnC e expnD exp1n -modnMm H modnMm mul1n gcdnC in H0.
- Qed.
-
-From mathcomp Require Import poly zmodp.
-Local Open Scope ring_scope.
-
-Lemma ord_pow2' (n : nat) (b : 'Z_(2^n.+3)):
- b^+(2^n.+1) = 1 :> 'Z_(2^n.+3) ->
- b^+(2^n) <> 1 :> 'Z_(2^n.+3) ->
- (2^(S n)).-primitive_root b.
-Proof.
- intros.
- by apply @two_pow_prim_root_alt.
-Qed.
-
-Lemma zp_m1_neq1 (n : nat) :
- n > 2 ->
- -1 <> 1 :> 'Z_n.
-Proof.
- intros.
- injection; unfold Zp_trunc; simpl.
- replace (n.-2.+2) with n by lia.
- have /modn_small->: (1 < n)%N by lia.
- have /modn_small->: (n-1 < n)%N by lia.
- lia.
-Qed.
-
-Lemma unit_pow_2_Zp (n : nat) (b : 'Z_(2^n.+1)) :
- b \is a unit <->
- odd b.
-Proof.
- have/(unitZpE b): (2^n.+1 > 1).
- {
- rewrite !expnS; lia.
- }
- rewrite (_ : (b%:R) = b) ?natr_Zp // => ->.
- rewrite -coprimen2 coprime_sym coprime_pexpr; lia.
-Qed.
-
-Lemma unit_pow_2_Zp' (n : nat) (b : {unit 'Z_(2^n.+1)}) :
- odd (val b).
-Proof.
- by rewrite -unit_pow_2_Zp ?(valP b).
-Qed.
-
-
-Lemma ord_5_pow_2_Zp' n :
- inZp (5 ^ (2^n)) = inZp (1 + 2^n.+2) :> 'Z_(2^n.+3).
-Proof.
- generalize (ord_5_pow_2 n); intros.
- rewrite /inZp.
- apply /eqP.
- rewrite /eq_op /=.
- rewrite Zp_cast; [| rewrite !expnS; lia].
- by apply /eqP.
-Qed.
-
-Lemma ord_5_pow_2_Zp n :
- inZp 5 ^+ (2^n) = inZp (1 + 2^n.+2) :> 'Z_(2^n.+3).
-Proof.
- rewrite -ord_5_pow_2_Zp'.
- by rewrite inZp_exp.
-Qed.
-
-Lemma ord_5_pow_2_Zp_1 n :
- inZp 5 ^+ (2^n.+1) = 1 :> 'Z_(2^n.+3).
-Proof.
- assert (odd5:odd 5) by by [].
- move: (ord_odd_pow_2' n odd5)=> b2n1_1.
- rewrite -inZp_exp.
- apply: val_inj => /=.
- rewrite Zp_cast; try lia.
- rewrite !expnS; lia.
-Qed.
-
-Lemma ord_3_pow_2_Zp_1 n :
- inZp 3 ^+ (2^n.+1) = 1 :> 'Z_(2^n.+3).
-Proof.
- assert (odd3:odd 3) by by [].
- move: (ord_odd_pow_2' n odd3)=> b2n1_1.
- rewrite -inZp_exp.
- apply: val_inj => /=.
- rewrite Zp_cast; try lia.
- rewrite !expnS; lia.
-Qed.
-
-Lemma primitive_5_pow2 n :
- let b5 : 'Z_(2^n.+3) := inZp 5 in
- (2^n.+1).-primitive_root b5.
-Proof.
- apply ord_pow2'.
- - apply ord_5_pow_2_Zp_1.
- - rewrite ord_5_pow_2_Zp.
- intros ?.
- apply (f_equal val) in H.
- simpl in H.
- rewrite Zp_cast in H; [|rewrite !expnS; lia].
- rewrite modn_small in H; [|rewrite !expnS; lia].
- rewrite modn_small in H; [|rewrite !expnS; lia].
- lia.
-Qed.
-
-Lemma m1_neq_pow5_mod2n (n : nat) :
- let b5 : 'Z_(2^n.+3) := inZp 5 in
- not (exists k, b5^+k = -1).
-Proof.
- generalize (primitive_5_pow2 n); intros.
- simpl in H.
- generalize (two_pow_prim_root_m1_alt b5 n H); intros.
- apply H0.
- - apply zp_m1_neq1.
- rewrite !expnS; lia.
- - rewrite ord_5_pow_2_Zp.
- rewrite /opp /= /Zp_opp.
- intros ?.
- apply (f_equal val) in H1.
- simpl in H1.
- rewrite Zp_cast in H1; [|rewrite !expnS; lia].
- rewrite modn_small in H1; [|rewrite !expnS; lia].
- rewrite modn_small in H1; [|rewrite !expnS; lia].
- rewrite modn_small in H1; [|rewrite !expnS; lia].
- rewrite !expnS in H1; lia.
-Qed.
-
-Lemma m1_neq_pow5_mod2n_gen (n : nat) :
- let b5 : 'Z_(2^n.+2) := inZp 5 in
- not (exists k, b5^+k = -1).
-Proof.
- destruct n.
- - simpl.
- set b5 : 'Z_(2^0.+2) := inZp 5.
- unfold not.
- intros.
- destruct H.
- assert (b5 = 1).
- {
- apply val_inj; simpl.
- rewrite /Zp_trunc /=.
- lia.
- }
- rewrite H0 expr1n in H.
- rewrite /opp /= /Zp_opp in H.
- apply (f_equal val) in H.
- simpl in H.
- rewrite Zp_cast in H; lia.
- - apply m1_neq_pow5_mod2n.
-Qed.
-
- Lemma pow_3_5_pow_2 n :
- 3^(2^n.+1) = 5^(2^n.+1) %[mod 2^n.+4].
- Proof.
- induction n.
- - lia.
- - symmetry.
- rewrite modn_sub_iff; [|rewrite leq_exp2r; lia].
- rewrite expnS !(mulnC 2%N _) !expnM subn_sqr.
- symmetry in IHn.
- rewrite modn_sub_iff in IHn; [|rewrite leq_exp2r; lia].
- rewrite (expnS _ n.+4) (mulnC 2%N _).
- rewrite mod_mul_mul_0_alt; trivial.
- split; trivial.
- rewrite modn2 mod0n oddD !oddX.
- lia.
- Qed.
-
- Lemma ord_3_pow_2_neq n :
- 3^(2^n) <> 1 %[mod 2^n.+3].
- Proof.
- destruct n.
- - lia.
- - rewrite pow_3_5_pow_2.
- apply ord_5_pow_2_neq.
- Qed.
-
- Lemma ord_3_pow_2_neq_m1 n :
- 3^(2^n) <> 2^n.+3-1 %[mod 2^n.+3].
- Proof.
- destruct n.
- - lia.
- - rewrite pow_3_5_pow_2.
- apply ord_5_pow_2_neq_m1.
- Qed.
-
- Lemma primitive_3_pow2 n :
- let b3 : 'Z_(2^n.+3) := inZp 3 in
- (2^n.+1).-primitive_root b3.
- Proof.
- apply ord_pow2'.
- - apply ord_3_pow_2_Zp_1.
- - generalize (@ord_3_pow_2_neq n); intros.
- unfold one; simpl.
- unfold Zp1.
- intros ?.
- rewrite -inZp_exp in H0.
- apply (f_equal val) in H0.
- simpl in H0.
- rewrite Zp_cast in H0; [|rewrite !expnS; lia].
- tauto.
- Qed.
-
- Lemma m1_neq_pow3_mod2n (n : nat) :
- let b3 : 'Z_(2^n.+3) := inZp 3 in
- not (exists k, b3^+k = -1).
-Proof.
- generalize (primitive_3_pow2 n); intros.
- simpl in H.
- generalize (two_pow_prim_root_m1_alt b3 n H); intros.
- apply H0.
- - apply zp_m1_neq1.
- rewrite !expnS; lia.
- - generalize (@ord_3_pow_2_neq_m1 n); intros.
- unfold opp; simpl.
- unfold Zp_opp.
- intros ?.
- unfold b3 in H2.
- rewrite -inZp_exp in H2.
- apply (f_equal val) in H2.
- simpl in H2.
- rewrite Zp_cast in H2; [|rewrite !expnS; lia].
- rewrite H2 in H1.
- clear H0 H2.
- rewrite modn_small in H1.
- + rewrite modn_small in H1; [|rewrite !expnS; lia].
- rewrite modn_small in H1; [|rewrite !expnS; lia].
- tauto.
- + rewrite modn_small; [|rewrite !expnS; lia].
- lia.
-Qed.
-
-From mathcomp Require Import finset eqtype finalg.
-From mathcomp Require Import fingroup.quotient.
-Section two_pow_units.
-
- Import GroupScope.
-Lemma ord_unit_pow_2_Zp (n : nat) (b : {unit 'Z_(2^n.+3)}) :
- b ^+ (2^n.+1) = 1.
-Proof.
- move: (unit_pow_2_Zp' b)=> bodd.
- move: (ord_odd_pow_2' n bodd)=> b2n1_1.
- move: (unit_Zp_expg b (2^n.+1)).
- rewrite /inZp.
- move/(f_equal val)=> /=.
- rewrite {3}Zp_cast; [| rewrite !expnS; lia].
- rewrite b2n1_1 => eqq.
- apply/eqP.
- rewrite /eq_op /= /eq_op /= eqq.
- rewrite Zp_cast; [| rewrite !expnS; lia].
- rewrite modn_small // !expnS; lia.
-Qed.
-
-Lemma ord_unit_pow_2_Zp' (n : nat) (b : {unit 'Z_(2^n.+3)}) :
- #[b] %| (2^n.+1)%N.
-Proof.
- rewrite order_dvdn.
- apply /eqP.
- apply ord_unit_pow_2_Zp.
-Qed.
-
-Lemma dvdn_prime_power x p n :
- prime p ->
- x %| p^n.+1 ->
- ~ x %| p^n ->
- (x = p^n.+1)%N.
-Proof.
- intros p_prime x_n1 x_n.
- generalize (prime_gt1 p_prime); intros pgt.
- move /dvdn_pfactor in x_n1.
- destruct (x_n1 p_prime).
- rewrite H0 (dvdn_Pexp2l x0 n pgt) in x_n.
- assert (x0 = n.+1) by lia.
- by rewrite H1 in H0.
-Qed.
-
-Lemma ord_unit_pow_2_Zp_max (n : nat) (b : {unit 'Z_(2^n.+3)}) :
- b ^+ (2^n) <> 1 ->
- #[b] = (2^n.+1)%N.
-Proof.
- intros.
- generalize (ord_unit_pow_2_Zp' b); intros.
- assert (~ #[b] %| 2^n).
- {
- intros ?.
- rewrite order_dvdn in H1.
- by move /eqP in H1.
- }
- by apply dvdn_prime_power.
-Qed.
-
-Lemma card_units_pow_2_Zp (n : nat) :
- #|units_Zp (2^n.+1)| = (2^n)%N.
-Proof.
- rewrite card_units_Zp; try lia.
- rewrite totient_pfactor // /=.
- lia.
-Qed.
-
-Lemma unit_Zp_gens_ord (n : nat) (a b : {unit 'Z_n}) :
- #|<[a]>%G :&: <[b]>%G| = 1%N ->
- #|<[a]> * <[b]>| = (#[a] * #[b])%N.
-Proof.
- intros.
- by rewrite mul_cardG H muln1.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_ord (n : nat) (a b : {unit 'Z_(2^n.+2)}) :
- #[a] = 2%N ->
- #[b] = (2^n)%N ->
- #|<[a]>%G :&: <[b]>%G| = 1%N ->
- #|<[a]> * <[b]>| = (2^n.+1)%N.
-Proof.
- intros.
- rewrite (unit_Zp_gens_ord H1) H H0; trivial.
- by rewrite expnS.
-Qed.
-
-Lemma ord2_setI_G1 (n : nat) (a b : {unit 'Z_(2^n)}) :
- #[a] = 2%N ->
- (a \notin <[b]>) ->
- <[a]>%G :&: <[b]>%G = 1.
-Proof.
- intros.
- have ->: (<[a]> :&: <[b]> = [set 1]).
- {
- rewrite (cycle2g H) setIUl.
- have /eqP->: ([set a] :&: <[b]> == set0).
- {
- by rewrite setI_eq0 disjoints1.
- }
- by rewrite setU0 -set1gE setI1g.
- }
- easy.
-Qed.
-
-Lemma ord2_setI (n : nat) (a b : {unit 'Z_(2^n)}) :
- #[a] = 2%N ->
- (a \notin <[b]>) ->
- #|<[a]>%G :&: <[b]>%G| = 1%N.
-Proof.
- intros.
- rewrite ord2_setI_G1; trivial.
- by rewrite cards1.
-Qed.
-
-Lemma unit_pow_2_Zp_gens (n : nat) (a b : {unit 'Z_(2^n.+2)}) :
- #[a] = 2%N ->
- #[b] = (2^n)%N ->
- a \notin <[b]> ->
- <[a]> <*> <[b]> = [group of (units_Zp (2^n.+2)%N)].
-Proof.
- intros.
- generalize (subsetT (<[a]> * <[b]>)%G); intros.
- apply index1g; trivial.
- rewrite -(divgS H2) (card_units_pow_2_Zp n.+1) joinGE /= norm_joinEr /=.
- - rewrite unit_pow_2_Zp_gens_ord //.
- + rewrite divnn !expnS; lia.
- + by apply ord2_setI.
- - apply cents_norm.
- eapply subset_trans.
- apply subsetT.
- apply sub_abelian_cent.
- + apply units_Zp_abelian.
- + apply subsetT.
-Qed.
-
-Lemma unit_3_pow_2_Zp (n : nat) :
- (3 : 'Z_(2^n.+1)) \is a unit.
-Proof.
- rewrite unitZpE.
- - rewrite coprimeXl //.
- - rewrite !expnS; lia.
-Qed.
-
-Lemma unit_5_pow_2_Zp (n : nat) :
- (5 : 'Z_(2^n.+1)) \is a unit.
-Proof.
- rewrite unitZpE.
- - rewrite coprimeXl //.
- - rewrite !expnS; lia.
-Qed.
-
-Lemma unit_odd_pow_2_Zp (j n : nat):
- odd j ->
- (inZp j : 'Z_(2^n.+1)) \is a unit.
-Proof.
- intros.
- rewrite unit_pow_2_Zp /= expnS Zp_cast; [|lia].
- rewrite odd_mod //.
- replace (2 * 2^n)%N with ((2^n).*2) by lia.
- by rewrite odd_double.
-Qed.
-
-Lemma unit_3_pow_2_Zp' (n : nat) :
- (3 : 'Z_(2^n.+1)) \is a unit.
-Proof.
- apply unit_odd_pow_2_Zp.
- rewrite /= expnS Zp_cast; lia.
-Qed.
-
-Lemma m1_not_in_unit_3_pow (n : nat) :
- FinRing.unit 'Z_(2 ^ n.+3) (unitrN1 (Zp_finUnitRingType (Zp_trunc (2 ^ n.+3))))
- \notin <[FinRing.unit 'Z_(2 ^ n.+3) (unit_3_pow_2_Zp n.+2)]>.
-Proof.
- have small1: 1 < 2 ^ n.+3 by (rewrite !expnS; lia).
- have small2: 2 < 2 ^ n.+3 by (rewrite !expnS; lia).
- have small3: 3 < 2 ^ n.+3 by (rewrite !expnS; lia).
- have nexist := @m1_neq_pow3_mod2n n.
- apply/negP.
- move/cyclePmin => [x xlt].
- move/(f_equal (fun (z : {unit 'Z_(2^n.+3)}) => val z)).
- rewrite /= unit_Zp_expg /= {2 3 4 5 6}Zp_cast // !modn_small // /inZp.
- move/(f_equal val) => /=.
- rewrite !Zp_cast // modn_small; [| rewrite !expnS; lia].
- rewrite modn_small // => pow3m1.
- apply nexist.
- exists x.
- rewrite /opp /= /Zp_opp {2}Zp_cast // -inZp_exp.
- apply val_inj.
- by rewrite /= !Zp_cast // -pow3m1 !modn_small //; rewrite !expnS; lia.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3 (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in
- <[um1]> <*> <[u3]> = [group of (units_Zp (2^n.+3)%N)].
-Proof.
- have small3: 3 < 2 ^ n.+3 by (rewrite !expnS; lia).
- have small1: 1 < 2 ^ n.+3 by lia.
- have small2: 2 < 2 ^ n.+3 by lia.
- apply unit_pow_2_Zp_gens.
- - apply nt_prime_order; trivial.
- + apply val_inj.
- by rewrite /= mulrNN mulr1.
- + apply /eqP.
- move/(f_equal FinRing.uval).
- simpl.
- by apply (zp_m1_neq1 small2).
- - apply ord_unit_pow_2_Zp_max.
- generalize (@ord_3_pow_2_neq n); intros.
- move/(f_equal (fun (z : {unit 'Z_(2^n.+3)}) => val z)).
- rewrite unit_Zp_expg /= {2 3 4 5 6}Zp_cast // !modn_small // /inZp.
- move/(f_equal val) => /=.
- rewrite !Zp_cast //.
- - apply m1_not_in_unit_3_pow.
-Qed.
-
-
-Lemma unit_Z4_gens_m1 :
- let um1 := FinRing.unit 'Z_4 (unitrN1 _) in
- <[um1]> = [group of (units_Zp 4)].
-Proof.
- intros.
- generalize (subsetT (<[um1]>)); intros.
- apply index1g; trivial.
- rewrite -(divgS H) (card_units_pow_2_Zp 1).
- assert (#[um1] = 2%N).
- {
- apply nt_prime_order; trivial.
- apply val_inj.
- by rewrite /= mulrNN mulr1.
- }
- unfold order in H0.
- rewrite H0.
- lia.
-Qed.
-
-Lemma unit_Z2 :
- <[1]> = [group of (units_Zp 2)].
-Proof.
- generalize (card_units_pow_2_Zp 0); intros.
- rewrite expn0 expn1 in H.
- apply card1_trivg in H.
- rewrite H.
- apply /eqP.
- apply cycle_eq1.
-Qed.
-
-Lemma unit_Z2_alt :
- let um1 := FinRing.unit 'Z_2 (unitrN1 _) in
- <[um1]> = [group of (units_Zp 2)].
-Proof.
- rewrite -unit_Z2.
- intros.
- f_equal.
- unfold um1.
- apply val_inj; simpl.
- apply val_inj; simpl.
- rewrite Zp_cast; lia.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3_gen (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+2) (unit_3_pow_2_Zp n.+1) in
- <[um1]> <*> <[u3]> = [group of (units_Zp (2^n.+2)%N)].
-Proof.
- destruct n.
- - intros.
- rewrite -unit_Z4_gens_m1.
- assert (<[u3]> \subset <[um1]>).
- {
- rewrite cycle_subG.
- assert (u3 = um1).
- {
- unfold u3, um1.
- apply val_inj; simpl.
- apply val_inj; simpl.
- rewrite Zp_cast; lia.
- }
- rewrite H.
- apply cycle_id.
- }
- move /joing_idPl in H.
- rewrite H.
- assert (um1 = FinRing.unit 'Z_4 (unitrN1 (Zp_finUnitRingType (Zp_trunc 4)))).
- {
- by apply val_inj; simpl.
- }
- by rewrite H0.
- - apply unit_pow_2_Zp_gens_m1_3.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3_gen_gen (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+1) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+1) (unit_3_pow_2_Zp n) in
- <[um1]> <*> <[u3]> = [group of (units_Zp (2^n.+1)%N)].
-Proof.
- destruct n.
- - intros.
- rewrite -unit_Z2_alt.
- assert (<[u3]> \subset <[um1]>).
- {
- rewrite cycle_subG.
- assert (u3 = um1).
- {
- unfold u3, um1.
- apply val_inj; simpl.
- apply val_inj; simpl.
- rewrite Zp_cast; lia.
- }
- rewrite H.
- apply cycle_id.
- }
- move /joing_idPl in H.
- rewrite H.
- assert (um1 = FinRing.unit 'Z_2 (unitrN1 (Zp_finUnitRingType (Zp_trunc 2)))).
- {
- by apply val_inj; simpl.
- }
- by rewrite H0.
- - apply unit_pow_2_Zp_gens_m1_3_gen.
-Qed.
-
-Lemma m1_not_in_unit_5_pow (n : nat) :
- FinRing.unit 'Z_(2 ^ n.+3) (unitrN1 (Zp_finUnitRingType (Zp_trunc (2 ^ n.+3))))
- \notin <[FinRing.unit 'Z_(2 ^ n.+3) (unit_5_pow_2_Zp n.+2)]>.
-Proof.
- have small5: 5 < 2 ^ n.+3 by (rewrite !expnS; lia).
- have small1: 1 < 2 ^ n.+3 by lia.
- have small2: 2 < 2 ^ n.+3 by lia.
- have small3: 3 < 2 ^ n.+3 by lia.
- have small4: 4 < 2 ^ n.+3 by lia.
- generalize (@m1_neq_pow5_mod2n n); intros.
- apply/negP.
- move/cyclePmin => [x xlt].
- move/(f_equal (fun (z : {unit 'Z_(2^n.+3)}) => val z)).
- rewrite /= unit_Zp_expg /= {2 3 4 5 6 7 8 9 10}Zp_cast // !modn_small // /inZp.
- move/(f_equal val) => /=.
- rewrite !Zp_cast // modn_small; [| rewrite !expnS; lia].
- rewrite modn_small // => HH.
- apply H.
- exists x.
- rewrite /opp /= /Zp_opp {2}Zp_cast // -inZp_exp.
- apply val_inj.
- by rewrite /= !Zp_cast // -HH !modn_small //; rewrite !expnS; lia.
-Qed.
-
-Lemma m1_not_in_unit_5_pow_gen (n : nat) :
- FinRing.unit 'Z_(2 ^ n.+2) (unitrN1 (Zp_finUnitRingType (Zp_trunc (2 ^ n.+2))))
- \notin <[FinRing.unit 'Z_(2 ^ n.+2) (unit_5_pow_2_Zp n.+1)]>.
-Proof.
- destruct n.
- - generalize (@m1_neq_pow5_mod2n_gen 0); intros.
- apply /negP.
- move/cyclePmin => [x xlt].
- move/(f_equal (fun (z : {unit 'Z_(2^0.+2)}) => val z)).
- rewrite /= unit_Zp_expg /= {2 3 4 5 6 7 8 9 10}Zp_cast //.
- have small3: 3 < 2 ^ 0.+2 by (rewrite !expnS; lia).
- have small1: 1 < 2 ^ 0.+2 by lia.
- have small2: 2 < 2 ^ 0.+2 by lia.
- rewrite (modn_small small3) (modn_small small1) // /inZp.
- move/(f_equal val) => /=.
- rewrite !Zp_cast //.
- rewrite (modn_small small1) exp1n.
- rewrite !modn_small; try lia.
- - apply m1_not_in_unit_5_pow.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_5 (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in
- let u5 := FinRing.unit 'Z_(2^n.+3) (unit_5_pow_2_Zp n.+2) in
- <[um1]> <*> <[u5]> = [group of (units_Zp (2^n.+3)%N)].
-Proof.
- have small5: 5 < 2 ^ n.+3 by (rewrite !expnS; lia).
- have small1: 1 < 2 ^ n.+3 by lia.
- have small2: 2 < 2 ^ n.+3 by lia.
- have small3: 3 < 2 ^ n.+3 by lia.
- have small4: 4 < 2 ^ n.+3 by lia.
- apply unit_pow_2_Zp_gens.
- - apply nt_prime_order; trivial.
- + apply val_inj.
- by rewrite /= mulrNN mulr1.
- + apply /eqP.
- move/(f_equal FinRing.uval).
- simpl.
- by apply (zp_m1_neq1 small2).
- - apply ord_unit_pow_2_Zp_max.
- generalize (@ord_5_pow_2_neq n); intros.
- move/(f_equal (fun (z : {unit 'Z_(2^n.+3)}) => val z)).
- rewrite unit_Zp_expg /= {2 3 4 5 6 7 8 9 10}Zp_cast // !modn_small // /inZp.
- move/(f_equal val) => /=.
- rewrite !Zp_cast //.
- - apply m1_not_in_unit_5_pow.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3_alt (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in
- <[um1]> * <[u3]> = [group of (units_Zp (2^n.+3)%N)].
-Proof.
- rewrite <- unit_pow_2_Zp_gens_m1_3.
- symmetry.
- apply comm_joingE.
- apply centC.
- apply cents_cycle.
- apply val_inj.
- now rewrite /mulg /FinRing.unit_mul /= /mul /= Zp_mulC.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3_alt_gen (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+2) (unit_3_pow_2_Zp n.+1) in
- <[um1]> * <[u3]> = [group of (units_Zp (2^n.+2)%N)].
-Proof.
- rewrite <- unit_pow_2_Zp_gens_m1_3_gen.
- symmetry.
- apply comm_joingE.
- apply centC.
- apply cents_cycle.
- apply val_inj.
- now rewrite /mulg /FinRing.unit_mul /= /mul /= Zp_mulC.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_5_alt (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in
- let u5 := FinRing.unit 'Z_(2^n.+3) (unit_5_pow_2_Zp n.+2) in
- <[um1]> * <[u5]> = [group of (units_Zp (2^n.+3)%N)].
-Proof.
- rewrite <- unit_pow_2_Zp_gens_m1_5.
- symmetry.
- apply comm_joingE.
- apply centC.
- apply cents_cycle.
- apply val_inj.
- now rewrite /mulg /FinRing.unit_mul /= /mul /= Zp_mulC.
-Qed.
-
-Lemma unit_pow_2_2_alt :
- let um1 := FinRing.unit 'Z_(2^2) (unitrN1 _) in
- let u5 := FinRing.unit 'Z_(2^2) (unit_5_pow_2_Zp 1) in
- <[um1]> * <[u5]> = [group of (units_Zp (2^2))].
-Proof.
- intros.
- generalize unit_Z4_gens_m1; intros.
- assert (u5 = 1).
- {
- unfold u5.
- apply val_inj; simpl.
- unfold Zp_trunc.
- apply val_inj; simpl.
- now rewrite modn_small.
- }
- generalize (cycle_eq1 u5); intros.
- move /eqP in H0.
- rewrite -H1 in H0.
- move /eqP in H0.
- rewrite H0.
- rewrite -comm_joingE.
- - rewrite joingG1.
- rewrite -H.
- unfold um1.
- by rewrite /Zp_trunc /=.
- - apply commute1.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_5_alt_gen (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in
- let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in
- <[um1]> * <[u5]> = [group of (units_Zp (2^n.+2)%N)].
-Proof.
- destruct n.
- - apply unit_pow_2_2_alt.
- - apply unit_pow_2_Zp_gens_m1_5_alt.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3_quo (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in
- <[u3]>/<[um1]> = [group of (units_Zp (2^n.+3)%N)]/<[um1]>.
-Proof.
- intros.
- rewrite - quotientMidl.
- by rewrite unit_pow_2_Zp_gens_m1_3_alt.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3_quo_alt (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in
- <[um1]>/<[u3]> = [group of (units_Zp (2^n.+3)%N)]/<[u3]>.
-Proof.
- intros.
- rewrite - quotientMidr.
- by rewrite unit_pow_2_Zp_gens_m1_3_alt.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_5_quo (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in
- let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in
- <[u5]>/<[um1]> = [group of (units_Zp (2^n.+2)%N)]/<[um1]>.
-Proof.
- intros.
- rewrite - quotientMidl.
- by rewrite unit_pow_2_Zp_gens_m1_5_alt_gen.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_5_quo_alt (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in
- let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in
- <[um1]>/<[u5]> = [group of (units_Zp (2^n.+2)%N)]/<[u5]>.
-Proof.
- intros.
- rewrite - quotientMidr.
- by rewrite unit_pow_2_Zp_gens_m1_5_alt_gen.
-Qed.
-
-Lemma quotient_isog_abelian (gT : finGroupType) (A H G : {group gT}) :
- abelian A ->
- H \subset A ->
- G \subset A ->
- H :&: G = 1 ->
- morphism.isog G (G / H).
-Proof.
- intros.
- apply quotient_isog; trivial.
- apply cents_norm.
- by apply (sub_abelian_cent2 H0).
-Qed.
-
-Lemma quotient_isog_unit_Zp (p : nat) (H G : {group {unit 'Z_p}}) :
- H :&: G = 1 ->
- morphism.isog G (G / H).
-Proof.
- intros.
- apply quotient_isog; trivial.
- apply cents_norm.
- eapply subset_trans.
- apply subsetT.
- apply sub_abelian_cent.
- + apply units_Zp_abelian.
- + apply subsetT.
-Qed.
-
-Lemma unitrN1_card_2 (n : nat) :
- #[FinRing.unit 'Z_(2 ^ n.+2) (unitrN1 _)] = 2%N.
-Proof.
- rewrite -(expn1 2).
- apply dvdn_prime_power; trivial.
- - rewrite order_dvdn expn1.
- apply /eqP.
- apply val_inj.
- by rewrite /= mulrNN mulr1.
- - rewrite expn0 /not.
- intros.
- rewrite order_dvdn expg1 expn1 in H.
- move /eqP in H.
- rewrite /oneg /= in H.
- apply (f_equal val) in H.
- rewrite /= /opp /one /= in H.
- apply (f_equal val) in H.
- rewrite /Zp_trunc /= in H.
- have small2: 2 < 2 ^ n.+2 by (rewrite !expnS; lia).
- have small1: 1 < 2 ^ n.+2 by lia.
- replace ((2^n.+2).-2.+2) with (2^n.+2)%N in H by lia.
- rewrite (modn_small small1) modn_small in H; lia.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3_quo_isog (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in
- morphism.isog <[u3]> (<[u3]>/<[um1]>).
-Proof.
- intros.
- apply quotient_isog_unit_Zp.
- apply ord2_setI_G1.
- - apply unitrN1_card_2.
- - apply m1_not_in_unit_3_pow.
- Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3_quo_isog_um1 (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in
- morphism.isog <[um1]> (<[um1]>/<[u3]>).
-Proof.
- intros.
- apply quotient_isog_unit_Zp.
- rewrite setIC.
- apply ord2_setI_G1.
- - apply unitrN1_card_2.
- - apply m1_not_in_unit_3_pow.
- Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3_quo_isog_alt (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in
- morphism.isog <[u3]> ([group of (units_Zp (2^n.+3)%N)]/<[um1]>).
-Proof.
- intros.
- simpl.
- rewrite -unit_pow_2_Zp_gens_m1_3_quo.
- apply unit_pow_2_Zp_gens_m1_3_quo_isog.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_3_quo_isog_um1_alt (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+3) (unitrN1 _) in
- let u3 := FinRing.unit 'Z_(2^n.+3) (unit_3_pow_2_Zp n.+2) in
- morphism.isog <[um1]> ([group of (units_Zp (2^n.+3)%N)]/<[u3]>).
-Proof.
- intros.
- simpl.
- rewrite -unit_pow_2_Zp_gens_m1_3_quo_alt.
- apply unit_pow_2_Zp_gens_m1_3_quo_isog_um1.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_5_quo_isog (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in
- let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in
- morphism.isog <[u5]> (<[u5]>/<[um1]>).
-Proof.
- intros.
- apply quotient_isog_unit_Zp.
- apply ord2_setI_G1.
- - apply unitrN1_card_2.
- - apply m1_not_in_unit_5_pow_gen.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_5_quo_isog_um1 (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in
- let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in
- morphism.isog <[um1]> (<[um1]>/<[u5]>).
-Proof.
- intros.
- apply quotient_isog_unit_Zp.
- rewrite setIC.
- apply ord2_setI_G1.
- - apply unitrN1_card_2.
- - apply m1_not_in_unit_5_pow_gen.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_5_quo_isog_alt (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in
- let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in
- morphism.isog <[u5]> ([group of (units_Zp (2^n.+2)%N)]/<[um1]>).
-Proof.
- intros.
- simpl.
- rewrite -unit_pow_2_Zp_gens_m1_5_quo.
- apply unit_pow_2_Zp_gens_m1_5_quo_isog.
-Qed.
-
-Lemma unit_pow_2_Zp_gens_m1_5_quo_isog_um1_alt (n : nat) :
- let um1 := FinRing.unit 'Z_(2^n.+2) (unitrN1 _) in
- let u5 := FinRing.unit 'Z_(2^n.+2) (unit_5_pow_2_Zp n.+1) in
- morphism.isog <[um1]> ([group of (units_Zp (2^n.+2)%N)]/<[u5]>).
-Proof.
- intros.
- simpl.
- rewrite -unit_pow_2_Zp_gens_m1_5_quo_alt.
- apply unit_pow_2_Zp_gens_m1_5_quo_isog_um1.
-Qed.
-
-End two_pow_units.
-
-From mathcomp Require Import matrix perm.
-Section add_self.
-
- Context {G:ringType}.
- Context {n:nat} (npos:0 < n).
-
- Definition modn_ord (m:nat) : 'I_n := Ordinal (@ltn_pmod m n npos).
-
- Lemma modn_ord_inj (i : 'I_n) :
- injective (fun j : ordinal_finType n => modn_ord (i + j)).
- Proof.
- rewrite /modn_ord => a b [xx].
- have: i + a + (n-i) = i + b + (n - i) %[mod n].
- {
- rewrite -[addn (addn _ (nat_of_ord a)) _ %% n]modnDml.
- rewrite -[addn (addn _ (nat_of_ord b)) _ %% n]modnDml.
- by rewrite xx.
- }
- rewrite (_:(addn (addn i a) (n-i) = addn a n)).
- - rewrite (_:(addn (addn i b) (n-i) = addn b n)).
- + rewrite !modnDr !modn_small; try apply ltn_ord.
- apply val_inj.
- + move: (ltn_ord i); lia.
- - move: (ltn_ord i); lia.
- Qed.
-
- Definition rotate_index_right_ord (idx:'I_n) (e:nat)
- := modn_ord (idx + e).
-
- Lemma rotate_ind_right_ord_cancel (e:nat) :
- cancel (fun (idx : 'I_n) => rotate_index_right_ord idx e)
- (fun (idx : 'I_n) => rotate_index_right_ord idx (n - e %% n)).
- Proof.
- rewrite /rotate_index_right_ord /cancel /modn_ord /=.
- intros.
- apply ord_inj=> /=.
- rewrite -(modnDm x e n) modnDml -addnA.
- generalize (ltn_pmod e npos); intros.
- replace (e %% n + (n - e %% n))%N with n by lia.
- rewrite -modnDmr modnn addn0 modn_mod modn_small //.
- Qed.
-
- Definition rot_perm e := perm (can_inj (rotate_ind_right_ord_cancel e)).
-
- Lemma rot_mul e1 e2 :
- perm_mul (rot_perm e1) (rot_perm e2) = rot_perm (e1 + e2).
- Proof.
- rewrite /perm_mul /rot_perm -permP.
- move => x.
- rewrite permE /= !permE /rotate_index_right_ord /modn_ord.
- apply ord_inj=> /=.
- by rewrite modnDml addnA.
- Qed.
-
- Definition rotate_row_right (v:'rV[G]_n) (e:nat)
- := \row_(i < n) v 0 (rotate_index_right_ord i e).
-
- Definition row_sum_naive_rot_row (v:'rV[G]_n)
- := \sum_(i < n) rotate_row_right v i.
-
- Definition row_sum_naive_rot (v:'rV[G]_n)
- := row_sum_naive_rot_row v 0 (Ordinal npos).
-
- Lemma row_sum_naive_rot_row_correct (v:'rV[G]_n)
- : row_sum_naive_rot_row v = const_mx (\sum_(j < n) v 0 j).
- Proof.
- apply/matrixP => rr i.
- rewrite /row_sum_naive_rot_row/const_mx/rotate_row_right/rotate_index_right_ord.
- rewrite !ord1 !mxE summxE /=.
- rewrite [\sum_(j modn_ord (i + nat_of_ord j)))).
- - apply eq_bigr => k _.
- by rewrite !mxE.
- - apply modn_ord_inj.
- Qed.
-
- Lemma row_sum_naive_rot_correct (v:'rV[G]_n)
- : row_sum_naive_rot v = \sum_(j < n) v 0 j.
- Proof.
- by rewrite/row_sum_naive_rot row_sum_naive_rot_row_correct/const_mx !mxE.
- Qed.
-
-End add_self.
-
-Section add_self_pow.
- Context {G:comRingType}.
-
- Lemma expn_2_pos (n : nat) : 0 < expn 2 n.
- Proof.
- lia.
- Qed.
-
- Fixpoint row_sum_rot_pow_rec {n} (v:'rV[G]_(2^n)) (m : nat) : 'rV[G]_(2^n) :=
- match m with
- | 0 => v
- | S m' => row_sum_rot_pow_rec (v + rotate_row_right (expn_2_pos n) v (2^m')) m'
- end.
-
- Lemma vals_modn_mul_0 i m n :
- 0 < m ->
- 0 < n ->
- i == 0 %[mod m] ->
- exists k, k < n /\ i == k * m %[mod m * n].
- Proof.
- intros.
- assert (exists q, i = (q * m)%N).
- {
- rewrite (modn_small H) in H1.
- by move /dvdnP in H1.
- }
- destruct H2.
- exists (x %% n).
- split.
- - by rewrite ltn_mod.
- - apply /eqP.
- by rewrite mulnC H2 muln_modl modn_mod.
- Qed.
-
- Lemma vals_modn_mul_le i j m n :
- 0 < m ->
- 0 < n ->
- j <= i ->
- i == j %[mod m] ->
- exists k, k < n /\ i == j + k * m %[mod m * n].
- Proof.
- intros.
- rewrite modn_sub in H2; trivial.
- apply (vals_modn_mul_0 (n:=n)) in H2; trivial.
- destruct H2 as [? [??]].
- exists x.
- split; trivial.
- assert (x * m < m * n).
- {
- rewrite mulnC.
- by rewrite ltn_pmul2l.
- }
- rewrite (modn_small H4) in H3.
- move /eqP in H3.
- apply (f_equal (fun z => modn (j + z) (m * n)%N)) in H3.
- apply /eqP.
- rewrite -H3 modnDmr.
- clear H3.
- replace (j + (i - j))%N with i; trivial.
- lia.
- Qed.
-
- Lemma vals_modn_mul i j m n :
- 0 < m ->
- 0 < n ->
- i == j %[mod m] ->
- exists k, k < n /\ i == j + k * m %[mod m * n].
- Proof.
- case (boolP (j <= i)).
- - intros; by apply vals_modn_mul_le.
- - intros.
- rewrite eq_sym in H1.
- apply (vals_modn_mul_le (n := n)) in H1; try lia.
- destruct H1 as [? [??]].
- destruct x.
- + rewrite mul0n addn0 in H2.
- exists 0%N.
- split; trivial.
- by rewrite mul0n addn0 eq_sym.
- + exists (n - x.+1)%N.
- split; try lia.
- move /eqP in H2.
- apply (f_equal (fun z => modn (z + (n - x.+1) * m)%N (m * n)%N)) in H2.
- rewrite modnDml in H2.
- rewrite H2 modnDml.
- replace (i + x.+1 * m + (n - x.+1) * m)%N with (i + m * n)%N.
- * by rewrite -modnDm modnn addn0 modn_mod.
- * clear H2.
- lia.
- Qed.
-
- Lemma vals_mod_2_Sn i j n :
- i == j %[mod 2^n] ->
- i == j %[mod 2^n.+1] \/ i == j + 2^n %[mod 2^n.+1].
- Proof.
- intros.
- case (boolP (i == j %[mod 2^n.+1])).
- - by left.
- - right.
- apply (vals_modn_mul (expn_2_pos n) (n := 2)) in H; try lia.
- destruct H as [? [??]].
- destruct x.
- + rewrite mul0n addn0 in H0.
- by rewrite expnS mulnC H0 in i0.
- + assert (x = 0%N) by lia.
- rewrite H1 mul1n in H0.
- by rewrite expnS mulnC.
- Qed.
-
- Definition is_partitioned_in_same_bins_by_m_to {n} (v:'rV[G]_(2^n)) m :=
- (forall i j, val i == val j %[mod 2^m] -> v 0 i = v 0 j).
-
- Lemma mod_mod_le k m n :
- m <= n ->
- (k %% 2^n) %% 2^m = k %% 2^m.
- Proof.
- intros.
- rewrite modn_dvdm; trivial.
- by apply dvdn_exp2l.
- Qed.
-
- Lemma row_sum_rot_pow_rec_step_narrows_bins {n} (v:'rV[G]_(2^n)) (m : nat) :
- S m <= n ->
- is_partitioned_in_same_bins_by_m_to v (S m) ->
- is_partitioned_in_same_bins_by_m_to (v + rotate_row_right (expn_2_pos n) v (2^m)) m.
- Proof.
- intro le_Sm_n.
- rewrite /is_partitioned_in_same_bins_by_m_to.
- intros.
- rewrite !mxE /rotate_index_right_ord.
- assert (val i == val j %[mod 2^m.+1] \/ val i == val j + 2^m %[mod 2^m.+1]).
- {
- by apply vals_mod_2_Sn.
- }
- destruct H1.
- - f_equal; apply H; simpl.
- + apply H1.
- + rewrite mod_mod_le; trivial.
- rewrite mod_mod_le; trivial.
- move /eqP in H1.
- apply /eqP.
- apply (f_equal (fun z => (z + 2^m)%N %% (2^m.+1))) in H1.
- by rewrite !modnDml in H1.
- - move /eqP in H1.
- assert (j = i + 2^m %[mod 2^m.+1]).
- {
- apply (f_equal (fun z => (z + 2^m)%N %% (2^m.+1))) in H1.
- rewrite !modnDml in H1.
- rewrite H1 -addnA.
- replace (2^m + 2^m)%N with (2^m.+1).
- - by rewrite modnDr.
- - rewrite expnS; lia.
- }
- rewrite addrC.
- f_equal; apply H; simpl.
- + by rewrite (mod_mod_le (i + 2^m)%N le_Sm_n) -H2.
- + by rewrite (mod_mod_le (j + 2^m)%N le_Sm_n) -H1.
- Qed.
-
- Lemma row_sum_rot_pow_rec_step_preserves_sum {n} (v:'rV[G]_(2^n)) (m : nat)
- (pf1:2^m.+1<=2^n) (pf2:2^m<=2^n):
- \sum_(i < 2^S m) v 0 (widen_ord pf1 i) =
- \sum_(i < 2^m) (v + rotate_row_right (expn_2_pos n) v (2^m)) 0 (widen_ord pf2 i).
- Proof.
- intros.
- under [\sum_(i < 2 ^ m) _] eq_bigr do rewrite !mxE.
-
- rewrite /= big_split /=.
-
- have pf1': (2 ^ m + 2 ^ m <= 2 ^ n) by (rewrite (leq_trans _ pf1) // expnS; lia).
- transitivity (\sum_(i < 2 ^ m + 2 ^ m) v 0 (widen_ord pf1' i)).
- - have: 2 ^ m.+1 = addn (2 ^ m) (2 ^ m)
- by (rewrite expnS; lia).
- destruct 1.
- apply eq_bigr => i _.
- by f_equal; apply val_inj.
- - rewrite big_split_ord /=.
- f_equal; (apply eq_bigr => i _; f_equal; apply val_inj => //=).
- rewrite addnC modn_small //.
- have leq1: i < 2 ^ m by apply ltn_ord.
- lia.
- Qed.
-
- Definition row_sum_rot_pow {n} (v:'rV[G]_(2^n)) := row_sum_rot_pow_rec v n.
-
- Lemma is_partitioned_in_same_bins_by_m_to0 {n} (v:'rV[G]_(2^n)) :
- is_partitioned_in_same_bins_by_m_to v n.
- Proof.
- move=> i j eqq.
- suff ->: i = j => //.
- rewrite !modn_small in eqq.
- - apply val_inj.
- by apply/eqP.
- - apply ltn_ord.
- - apply ltn_ord.
- Qed.
-
- Lemma row_sum_rot_pow_is_really_binned {n} (v:'rV[G]_(2^n)) :
- is_partitioned_in_same_bins_by_m_to (row_sum_rot_pow v) 0.
- Proof.
- rewrite /row_sum_rot_pow.
- suff {v}: forall n', n' <= n -> forall v : 'rV_(2 ^ n),
- is_partitioned_in_same_bins_by_m_to v n' ->
- is_partitioned_in_same_bins_by_m_to (row_sum_rot_pow_rec v n') 0.
- { apply.
- - apply leqnn.
- - apply is_partitioned_in_same_bins_by_m_to0.
- }
- induction n' => //= n'l v.
- move/row_sum_rot_pow_rec_step_narrows_bins => HH.
- apply IHn'.
- - by apply ltnW.
- - by apply HH.
- Qed.
-
- Lemma row_sum_rot_pow_is_summed {n} (v:'rV[G]_(2^n)) :
- \sum_(i < 2^n) v 0 i = (row_sum_rot_pow v) 0 (Ordinal (expn_2_pos n)).
- Proof.
- rewrite /row_sum_rot_pow.
- suff {v} HH: forall n' (pf:2^n'<=2^n), forall v' : 'rV_(2 ^ n),
- \sum_(i < 2^n') v' 0 (widen_ord pf i) =
- (row_sum_rot_pow_rec v' n') 0 (Ordinal (expn_2_pos n)).
- {
- rewrite <- (HH n (leqnn _) v).
- - apply eq_big => // i _.
- f_equal.
- by apply ord_inj.
- }
- induction n' => /= pf v'.
- - rewrite (big_pred1_id _ _ _ _ (i:=0)).
- + rewrite addr0.
- f_equal.
- by apply val_inj.
- + move=> i /=.
- by rewrite ord1 eqxx.
- - have pf': is_true (2 ^ n' <= 2 ^ n).
- {
- rewrite (leq_trans _ pf) // expnS.
- lia.
- }
- rewrite <- (IHn' pf').
- by apply row_sum_rot_pow_rec_step_preserves_sum.
- Qed.
-
- (* claim at kth iteration v is a concatenation of 2^k equal vectors each of which has the same sum as the original v. *)
- Lemma row_sum_rot_pow_correct {n} (v:'rV[G]_(2^n))
- : row_sum_rot_pow v = const_mx (\sum_(j < 2^n) v 0 j).
- Proof.
- apply/matrixP => rr i.
- rewrite !ord1 /const_mx !mxE row_sum_rot_pow_is_summed /row_sum_rot_pow .
- apply row_sum_rot_pow_is_really_binned.
- by rewrite expn0 !modn1.
- Qed.
-
-End add_self_pow.
diff --git a/coq/NeuralNetworks/AxiomaticNormedRealVectorSpace.v b/coq/NeuralNetworks/AxiomaticNormedRealVectorSpace.v
deleted file mode 100644
index d808c207..00000000
--- a/coq/NeuralNetworks/AxiomaticNormedRealVectorSpace.v
+++ /dev/null
@@ -1,45 +0,0 @@
-(*using definitions from http://www.math.ucla.edu/~tao/resource/general/121.1.00s/vector_axioms.html *)
-
-Require Import Reals.Rbase.
-Require Import Reals.Rfunctions.
-
-Module AxiomaticNormedRealVectorSpace.
-
-Inductive rvector (d: nat) : Set :=
-| zero
-| add (x y : rvector d)
-| inverse (x: rvector d)
-| smult (r:R) (x: rvector d).
-
-Class NormedVectorSpace (d: nat) :=
- {
- norm: rvector d -> R;
-
- rv_axiom_additive : forall x y z : rvector d,
- (add d x y) = (add d y x) /\
- (add d (add d x y) z) = (add d x (add d y z)) /\
- (add d (zero d) x) = (add d x (zero d)) /\
- (add d x (zero d)) = x /\
- (add d (inverse d x) x) = (add d x (inverse d x)) /\
- (add d x (inverse d x)) = zero d;
-
- rv_axiom_multiplicaive : forall (x : rvector d), forall (b c : R),
- smult d R0 x = zero d /\
- smult d R1 x = x /\
- smult d (Rmult b c) x = smult d b (smult d c x);
-
- rv_axiom_distributive : forall (x y : rvector d), forall (a b : R),
- smult d b (add d x y) = add d (smult d b x) (smult d b y) /\
- smult d (Rplus a b) x = add d (smult d a x) (smult d b x);
-
- norm_axiom_zero : forall (x:rvector d),
- norm (zero d) = R0 <-> x=zero d;
-
- norm_axiom_abs : forall (x: rvector d), forall (a:R),
- norm (smult d a x) = Rmult (Rabs a) (norm x);
-
- norm_axiom_add : forall (x y : rvector d),
- norm (add d x y) = Rplus (norm x) (norm y);
- }.
-
-End AxiomaticNormedRealVectorSpace.
diff --git a/coq/NeuralNetworks/DefinedFunctions.v b/coq/NeuralNetworks/DefinedFunctions.v
deleted file mode 100644
index 408fbb8e..00000000
--- a/coq/NeuralNetworks/DefinedFunctions.v
+++ /dev/null
@@ -1,15581 +0,0 @@
-Require Import Program.
-Require Import String.
-Require Import EquivDec.
-Require Import RelationClasses.
-Require Import List.
-Require Import Permutation.
-Require Import NPeano BinInt PeanoNat.
-Require Import Lra Lia.
-Require Reals.
-Require Import Eqdep_dec.
-
-Require Import Floatish.
-Require Import Utils.
-Require Import derivlemmas.
-Require Import Vector.
-
-Import ListNotations.
-
-Local Open Scope list_scope.
-Declare Scope df_scope.
-
-Set Bullet Behavior "Strict Subproofs".
-
-Section DefinedFunctions.
-
- Context {floatish_impl:floatish}.
- Local Open Scope float.
-
-(* in pytorch relu(f)' if f <=0 then 0 else f' *)
-(* in pytorch abs(f)' = f'*sign(f) *)
-(* max(a,b)' = if a<=b then b' else a' *)
-(* min(a,b)' = if a>=b then b' else a' *)
-(* x = Variable(torch.tensor(0.0), requires_grad=True) *)
-(* z = torch.min(x*x, x); z.backward(); print(x.grad) = 1 *)
-(* x.grad.data.zero_() between tests *)
-(* relu behaves like max(x, 0), not max(0,x), i.e. relu(x)' at 0 = 0 *)
-
-
- Section Definitions.
-
- Definition var := string.
-
- Inductive SubVar : Set :=
- | Name (s : string)
- | Sub (v : SubVar) (i : nat).
-
-
- Definition var_dec : forall v1 v2 : SubVar, {v1 = v2} + {v1 <> v2}.
- Proof.
- decide equality.
- apply string_dec.
- apply Nat.eq_dec.
- Defined.
-
- Global Instance var_eqdec : EqDec SubVar eq.
- Proof.
- intros x y.
- apply var_dec.
- Defined.
-
- (* A subset of defined functions *)
-
-
- Inductive definition_function_types
- := DTfloat
- | DTVector (n:nat)
- | DTMatrix (m n:nat).
-
- Definition definition_function_types_interp (dft:definition_function_types) : Type
- := match dft with
- | DTfloat => float
- | DTVector n => Vector float n
- | DTMatrix m n => Matrix float m n
- end.
-
- Inductive data_type : definition_function_types -> Type
- := DataFloat : data_type DTfloat
- | DataVector n (v:Vector float n) : data_type (DTVector n)
- | DataMatrix m n (mat:Matrix float m n) : data_type (DTMatrix m n).
-
- Definition var_type := (SubVar * definition_function_types)%type.
-
- Definition definition_function_types_dec : forall v1 v2 : definition_function_types, {v1 = v2} + {v1 <> v2}.
- Proof.
- decide equality; apply Nat.eq_dec.
- Defined.
-
- Definition vart_dec : forall v1 v2 : var_type, {v1 = v2} + {v1 <> v2}.
- Proof.
- decide equality.
- - apply definition_function_types_dec.
- - apply var_dec.
- Defined.
-
- Global Instance vart_eqdec : EqDec var_type eq.
- Proof.
- intros ??.
- apply vart_dec.
- Defined.
-
- Lemma var_type_UIP_refl {x:var_type} (e:x=x) : e = eq_refl x.
- Proof.
- apply (UIP_dec vart_dec).
- Qed.
-
- Lemma definition_function_types_UIP_refl {x:definition_function_types} (e:x=x) : e = eq_refl x.
- Proof.
- apply (UIP_dec definition_function_types_dec).
- Qed.
-
- Definition env_entry_type := {v:var_type & definition_function_types_interp (snd v)}.
- Definition df_env := list env_entry_type.
-
- Definition mk_env_entry v e : env_entry_type
- := let P := fun xv => definition_function_types_interp (snd xv) in
- existT P v e.
-
- Definition UnitAnn: definition_function_types->Type := fun _ => unit.
- Definition EvalAnn: definition_function_types->Type := definition_function_types_interp.
-
- Inductive DefinedFunction {Ann:definition_function_types->Type} : definition_function_types -> Type :=
- | Number (ann:Ann DTfloat) (x : float) : DefinedFunction DTfloat
- | Constant {t:definition_function_types} (ann:Ann t) (x : definition_function_types_interp t) : DefinedFunction t
- | DVector {n} (ann:Ann (DTVector n)) (x : Vector (DefinedFunction DTfloat) n) : DefinedFunction (DTVector n)
- | DMatrix {n m} (ann:Ann (DTMatrix n m)) (x : Matrix (DefinedFunction DTfloat) n m) : DefinedFunction (DTMatrix n m)
- | Var (v : var_type) (ann: Ann (snd v)) : DefinedFunction (snd v)
- | Plus (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Minus (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Times (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Divide (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Square (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Exp (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Log (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Abs (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Sign (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | PSign (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Max (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | VectorDot {n} (ann:Ann DTfloat) (l r: DefinedFunction (DTVector n)) : DefinedFunction DTfloat
- | VectorSum {n} (ann:Ann DTfloat) (v: DefinedFunction (DTVector n)) : DefinedFunction DTfloat
- | MatrixSum {m n} (ann:Ann DTfloat) (v: DefinedFunction (DTMatrix m n)) : DefinedFunction DTfloat
- | VectorElem {n} (ann:Ann DTfloat) (l:DefinedFunction (DTVector n)) (i:{x:nat|x Prop)
- (f : forall (ann : UnitAnn DTfloat) (x : float),
- P DTfloat (Number ann x))
- (f0 : forall (t : definition_function_types)
- (ann : UnitAnn t) (x : definition_function_types_interp t), P t (Constant ann x))
- (f1 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (x : Vector (DefinedFunction UnitAnn DTfloat) n),
- (forall s : {n' : nat | (n' < n)%nat}, P DTfloat (x s)) ->
- P (DTVector n) (DVector ann x))
- (f2 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (x : Matrix (DefinedFunction UnitAnn DTfloat) n m),
- (forall (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}),
- P DTfloat (x s s0)) -> P (DTMatrix n m) (DMatrix ann x))
- (f3 : forall (v : var_type) (ann : UnitAnn (snd v)),
- P (snd v) (Var v ann))
- (f4 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Plus ann l r))
- (f5 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Minus ann l r))
- (f6 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Times ann l r))
- (f7 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Divide ann l r))
- (f8 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Square ann e))
- (f9 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Exp ann e))
- (f10 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Log ann e))
- (f11 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Abs ann e))
- (f12 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Sign ann e))
- (f13 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (PSign ann e))
- (f14 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Max ann l r))
- (f15 : forall (n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P DTfloat (VectorDot ann l r))
- (f16 : forall (n : nat) (ann : UnitAnn DTfloat)
- (v : DefinedFunction UnitAnn (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v))
- (f17 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (v : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v))
- (f18 : forall (n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i))
- (f19 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}),
- P DTfloat (MatrixElem ann l i j))
- (f20 : forall (m n : nat) (ann : UnitAnn (DTVector m))
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r))
- (f21 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n))
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction UnitAnn (DTVector m),
- P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r))
- (f22 : forall (m p n : nat) (ann : UnitAnn (DTMatrix m n))
- (l : DefinedFunction UnitAnn (DTMatrix m p)),
- P (DTMatrix m p) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix p n),
- P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r))
- (f23 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r))
- (f24 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r))
- (f25 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (l : DefinedFunction UnitAnn (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r))
- (f26 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (l : DefinedFunction UnitAnn (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r))
- (f27 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (x : DefinedFunction UnitAnn DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l))
- (f28 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (x : DefinedFunction UnitAnn DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l))
- (f29 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l))
- (f30 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTMatrix m n),
- P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l))
- (f31 : forall (n : nat) (ann : UnitAnn DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> forall r : Vector float n, P DTfloat (VLossfun ann v1 v2 s l r))
- (f32 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTMatrix m n),
- P (DTMatrix m n) l ->
- forall r : Matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r))
- :=
-fix
-F (d : definition_function_types)
- (d0 : DefinedFunction UnitAnn d) {struct d0} : P d d0 :=
- match d0 as d2 in (DefinedFunction _ d1) return (P d1 d2) with
- | Number ann x => f ann x
- | @Constant _ t ann x => f0 t ann x
- | @DVector _ n ann x => f1 n ann x (fun s : {n' : nat | (n' < n)%nat} => F DTfloat (x s))
- | @DMatrix _ n m ann x =>
- f2 n m ann x
- (fun (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}) =>
- F DTfloat (x s s0))
- | Var v ann => f3 v ann
- | Plus ann l r => f4 ann l (F DTfloat l) r (F DTfloat r)
- | Minus ann l r => f5 ann l (F DTfloat l) r (F DTfloat r)
- | Times ann l r => f6 ann l (F DTfloat l) r (F DTfloat r)
- | Divide ann l r => f7 ann l (F DTfloat l) r (F DTfloat r)
- | Square ann e => f8 ann e (F DTfloat e)
- | Exp ann e => f9 ann e (F DTfloat e)
- | Log ann e => f10 ann e (F DTfloat e)
- | Abs ann e => f11 ann e (F DTfloat e)
- | Sign ann e => f12 ann e (F DTfloat e)
- | PSign ann e => f13 ann e (F DTfloat e)
- | Max ann l r => f14 ann l (F DTfloat l) r (F DTfloat r)
- | @VectorDot _ n ann l r => f15 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @VectorSum _ n ann v => f16 n ann v (F (DTVector n) v)
- | @MatrixSum _ m n ann v => f17 m n ann v (F (DTMatrix m n) v)
- | @VectorElem _ n ann l i => f18 n ann l (F (DTVector n) l) i
- | @MatrixElem _ m n ann l i j => f19 m n ann l (F (DTMatrix m n) l) i j
- | @MatrixVectorMult _ m n ann l r =>
- f20 m n ann l (F (DTMatrix m n) l) r (F (DTVector n) r)
- | @MatrixVectorAdd _ m n ann l r =>
- f21 m n ann l (F (DTMatrix m n) l) r (F (DTVector m) r)
- | @MatrixMult _ m p n ann l r =>
- f22 m p n ann l (F (DTMatrix m p) l) r (F (DTMatrix p n) r)
- | @VectorPlus _ n ann l r => f23 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @VectorMinus _ n ann l r => f24 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @MatrixPlus _ n m ann l r => f25 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r)
- | @MatrixMinus _ n m ann l r =>
- f26 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r)
- | @VectorScalMult _ n ann x l => f27 n ann x (F DTfloat x) l (F (DTVector n) l)
- | @MatrixScalMult _ n m ann x l => f28 n m ann x (F DTfloat x) l (F (DTMatrix n m) l)
- | @VectorApply _ n ann v s l => f29 n ann v s (F DTfloat s) l (F (DTVector n) l)
- | @MatrixApply _ m n ann v s l =>
- f30 m n ann v s (F DTfloat s) l (F (DTMatrix m n) l)
- | @VLossfun _ n ann v1 v2 s l r =>
- f31 n ann v1 v2 s (F DTfloat s) l (F (DTVector n) l) r
- | @MLossfun _ m n ann v1 v2 s l r =>
- f32 m n ann v1 v2 s (F DTfloat s) l (F (DTMatrix m n) l) r
- end.
-
-Definition DefinedFunction_ind_simpl {Ann}
- (P : forall (d : definition_function_types), DefinedFunction Ann d -> Prop)
- (f : forall (ann : Ann DTfloat) (x : float),
- P DTfloat (Number ann x))
- (f0 : forall (t : definition_function_types)
- (ann : Ann t) (x : definition_function_types_interp t), P t (Constant ann x))
- (f1 : forall (n : nat) (ann : Ann (DTVector n))
- (x : Vector (DefinedFunction Ann DTfloat) n),
- (forall s : {n' : nat | (n' < n)%nat}, P DTfloat (x s)) ->
- P (DTVector n) (DVector ann x))
- (f2 : forall (n m : nat) (ann : Ann (DTMatrix n m))
- (x : Matrix (DefinedFunction Ann DTfloat) n m),
- (forall (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}),
- P DTfloat (x s s0)) -> P (DTMatrix n m) (DMatrix ann x))
- (f3 : forall (v : var_type) (ann : Ann (snd v)),
- P (snd v) (Var v ann))
- (f4 : forall (ann : Ann DTfloat)
- (l : DefinedFunction Ann DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Plus ann l r))
- (f5 : forall (ann : Ann DTfloat)
- (l : DefinedFunction Ann DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Minus ann l r))
- (f6 : forall (ann : Ann DTfloat)
- (l : DefinedFunction Ann DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Times ann l r))
- (f7 : forall (ann : Ann DTfloat)
- (l : DefinedFunction Ann DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Divide ann l r))
- (f8 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Square ann e))
- (f9 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Exp ann e))
- (f10 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Log ann e))
- (f11 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Abs ann e))
- (f12 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Sign ann e))
- (f13 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (PSign ann e))
- (f14 : forall (ann : Ann DTfloat)
- (l : DefinedFunction Ann DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Max ann l r))
- (f15 : forall (n : nat) (ann : Ann DTfloat)
- (l : DefinedFunction Ann (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction Ann (DTVector n),
- P (DTVector n) r -> P DTfloat (VectorDot ann l r))
- (f16 : forall (n : nat) (ann : Ann DTfloat)
- (v : DefinedFunction Ann (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v))
- (f17 : forall (m n : nat) (ann : Ann DTfloat)
- (v : DefinedFunction Ann (DTMatrix m n)),
- P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v))
- (f18 : forall (n : nat) (ann : Ann DTfloat)
- (l : DefinedFunction Ann (DTVector n)),
- P (DTVector n) l ->
- forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i))
- (f19 : forall (m n : nat) (ann : Ann DTfloat)
- (l : DefinedFunction Ann (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}),
- P DTfloat (MatrixElem ann l i j))
- (f20 : forall (m n : nat) (ann : Ann (DTVector m))
- (l : DefinedFunction Ann (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction Ann (DTVector n),
- P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r))
- (f21 : forall (m n : nat) (ann : Ann (DTMatrix m n))
- (l : DefinedFunction Ann (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction Ann (DTVector m),
- P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r))
- (f22 : forall (m p n : nat) (ann : Ann (DTMatrix m n))
- (l : DefinedFunction Ann (DTMatrix m p)),
- P (DTMatrix m p) l ->
- forall r : DefinedFunction Ann (DTMatrix p n),
- P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r))
- (f23 : forall (n : nat) (ann : Ann (DTVector n))
- (l : DefinedFunction Ann (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction Ann (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r))
- (f24 : forall (n : nat) (ann : Ann (DTVector n))
- (l : DefinedFunction Ann (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction Ann (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r))
- (f25 : forall (n m : nat) (ann : Ann (DTMatrix n m))
- (l : DefinedFunction Ann (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction Ann (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r))
- (f26 : forall (n m : nat) (ann : Ann (DTMatrix n m))
- (l : DefinedFunction Ann (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction Ann (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r))
- (f27 : forall (n : nat) (ann : Ann (DTVector n))
- (x : DefinedFunction Ann DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction Ann (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l))
- (f28 : forall (n m : nat) (ann : Ann (DTMatrix n m))
- (x : DefinedFunction Ann DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction Ann (DTMatrix n m),
- P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l))
- (f29 : forall (n : nat) (ann : Ann (DTVector n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- forall l : DefinedFunction Ann (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l))
- (f30 : forall (m n : nat) (ann : Ann (DTMatrix m n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- forall l : DefinedFunction Ann (DTMatrix m n),
- P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l))
- (f31 : forall (n : nat) (ann : Ann DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- forall l : DefinedFunction Ann (DTVector n),
- P (DTVector n) l -> forall r : Vector float n, P DTfloat (VLossfun ann v1 v2 s l r))
- (f32 : forall (m n : nat) (ann : Ann DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- forall l : DefinedFunction Ann (DTMatrix m n),
- P (DTMatrix m n) l ->
- forall r : Matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r))
- :=
-fix
-F (d : definition_function_types)
- (d0 : DefinedFunction Ann d) {struct d0} : P d d0 :=
- match d0 as d2 in (DefinedFunction _ d1) return (P d1 d2) with
- | Number ann x => f ann x
- | @Constant _ t ann x => f0 t ann x
- | @DVector _ n ann x => f1 n ann x (fun s : {n' : nat | (n' < n)%nat} => F DTfloat (x s))
- | @DMatrix _ n m ann x =>
- f2 n m ann x
- (fun (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}) =>
- F DTfloat (x s s0))
- | Var v ann => f3 v ann
- | Plus ann l r => f4 ann l (F DTfloat l) r (F DTfloat r)
- | Minus ann l r => f5 ann l (F DTfloat l) r (F DTfloat r)
- | Times ann l r => f6 ann l (F DTfloat l) r (F DTfloat r)
- | Divide ann l r => f7 ann l (F DTfloat l) r (F DTfloat r)
- | Square ann e => f8 ann e (F DTfloat e)
- | Exp ann e => f9 ann e (F DTfloat e)
- | Log ann e => f10 ann e (F DTfloat e)
- | Abs ann e => f11 ann e (F DTfloat e)
- | Sign ann e => f12 ann e (F DTfloat e)
- | PSign ann e => f13 ann e (F DTfloat e)
- | Max ann l r => f14 ann l (F DTfloat l) r (F DTfloat r)
- | @VectorDot _ n ann l r => f15 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @VectorSum _ n ann v => f16 n ann v (F (DTVector n) v)
- | @MatrixSum _ m n ann v => f17 m n ann v (F (DTMatrix m n) v)
- | @VectorElem _ n ann l i => f18 n ann l (F (DTVector n) l) i
- | @MatrixElem _ m n ann l i j => f19 m n ann l (F (DTMatrix m n) l) i j
- | @MatrixVectorMult _ m n ann l r =>
- f20 m n ann l (F (DTMatrix m n) l) r (F (DTVector n) r)
- | @MatrixVectorAdd _ m n ann l r =>
- f21 m n ann l (F (DTMatrix m n) l) r (F (DTVector m) r)
- | @MatrixMult _ m p n ann l r =>
- f22 m p n ann l (F (DTMatrix m p) l) r (F (DTMatrix p n) r)
- | @VectorPlus _ n ann l r => f23 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @VectorMinus _ n ann l r => f24 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @MatrixPlus _ n m ann l r => f25 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r)
- | @MatrixMinus _ n m ann l r =>
- f26 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r)
- | @VectorScalMult _ n ann x l => f27 n ann x (F DTfloat x) l (F (DTVector n) l)
- | @MatrixScalMult _ n m ann x l => f28 n m ann x (F DTfloat x) l (F (DTMatrix n m) l)
- | @VectorApply _ n ann v s l => f29 n ann v s l (F (DTVector n) l)
- | @MatrixApply _ m n ann v s l =>
- f30 m n ann v s l (F (DTMatrix m n) l)
- | @VLossfun _ n ann v1 v2 s l r =>
- f31 n ann v1 v2 s l (F (DTVector n) l) r
- | @MLossfun _ m n ann v1 v2 s l r =>
- f32 m n ann v1 v2 s l (F (DTMatrix m n) l) r
- end.
-
- Definition get_annotation {Ann T} (df:DefinedFunction Ann T) : Ann T
- := match df with
- | Number ann _ => ann
- | Constant _ ann _ => ann
- | DVector _ ann _ => ann
- | DMatrix _ _ ann _ => ann
- | Var _ ann => ann
- | Plus ann _ _ => ann
- | Minus ann _ _ => ann
- | Times ann _ _ => ann
- | Divide ann _ _ => ann
- | Square ann _ => ann
- | Exp ann _ => ann
- | Log ann _ => ann
- | Abs ann _ => ann
- | Sign ann _ => ann
- | PSign ann _ => ann
- | Max ann _ _ => ann
- | VectorDot _ ann _ _ => ann
- | VectorSum _ ann _ => ann
- | MatrixSum _ _ ann _ => ann
- | VectorElem _ ann _ _ => ann
- | MatrixElem _ _ ann _ _ _ => ann
- | MatrixVectorMult _ _ ann _ _ => ann
- | MatrixVectorAdd _ _ ann _ _ => ann
- | MatrixMult _ _ _ ann _ _ => ann
- | VectorPlus _ ann _ _ => ann
- | VectorMinus _ ann _ _ => ann
- | MatrixPlus _ _ ann _ _ => ann
- | MatrixMinus _ _ ann _ _ => ann
- | VectorScalMult _ ann _ _ => ann
- | MatrixScalMult _ _ ann _ _ => ann
- | VectorApply _ ann _ _ _ => ann
- | MatrixApply _ _ ann _ _ _ => ann
- | VLossfun _ ann _ _ _ _ _ => ann
- | MLossfun _ _ ann _ _ _ _ _ => ann
- end.
-
- Definition dft_eq_dec :
- forall (t1 t2 : definition_function_types), {t1 = t2} + {t1 <> t2}.
- Proof.
- decide equality.
- decide equality.
- apply Nat.eq_dec.
- apply Nat.eq_dec.
- Defined.
-
- Global Instance dft_eqdec : EqDec definition_function_types eq.
- Proof.
- intros ??.
- apply dft_eq_dec.
- Defined.
-
- End Definitions.
-
- Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "DVector"%string
- | Case_aux c "DMatrix"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string
- | Case_aux c "VectorDot"%string
- | Case_aux c "VectorSum"%string
- | Case_aux c "MatrixSum"%string
- | Case_aux c "VectorElem"%string
- | Case_aux c "MatrixElem"%string
- | Case_aux c "MatrixVectorMult"%string
- | Case_aux c "MatrixVectorAdd"%string
- | Case_aux c "MatrixMult"%string
- | Case_aux c "VectorPlus"%string
- | Case_aux c "VectorMinus"%string
- | Case_aux c "MatrixPlus"%string
- | Case_aux c "MatrixMinus"%string
- | Case_aux c "VectorScalMult"%string
- | Case_aux c "MatrixScalMult"%string
- | Case_aux c "VectorApply"%string
- | Case_aux c "MatrixApply"%string
- | Case_aux c "VLossfun"%string
- | Case_aux c "MLossfun"%string].
-
-
- Ltac refl_simpler :=
- repeat
- match goal with
- | [H: @eq var_type _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H)
- | [H: @equiv var_type _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H)
- | [H: @eq definition_function_types _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H)
- | [H: @equiv definition_function_types _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H)
- end.
-
-
- Definition df_plus (df1 df2 : DefinedFunction UnitAnn DTfloat) : DefinedFunction UnitAnn DTfloat :=
- Plus tt df1 df2.
-
- Definition df_times (df1 df2 : DefinedFunction UnitAnn DTfloat) : DefinedFunction UnitAnn DTfloat :=
- Times tt df1 df2.
-
- Definition defined_sum {m} (v:Vector (DefinedFunction UnitAnn DTfloat) m) : DefinedFunction UnitAnn DTfloat
- := vector_fold_right1 (fun a b => Plus tt a b) (Number tt 0) id v.
-
- Definition vsum {m:nat} (v:Vector float m) : float
- := vector_fold_right1 Fplus 0 id v.
-
- Definition msum {m n:nat} (v:Matrix float m n) : float :=
- vsum (vmap vsum v).
-
- Definition matrix_vector_mult {n m} (l : Matrix float n m)(r : Vector float m) : Vector float n :=
- fun i => vsum (fun j => (l i j) * (r j)).
-
- Definition matrix_vector_add {n m} (l : Matrix float n m) (r : Vector float n) : Matrix float n m := fun i j => (l i j) + (r i).
-
- Definition matrix_mult {n m p} (l : Matrix float n m)(r : Matrix float m p) : Matrix float n p :=
- fun i k => vsum (fun j => (l i j) * (r j k)).
-
-
- Section deriv.
-
-
- Section subst.
-
- Definition substvar {Ann} (v vv:var_type) (e':DefinedFunction Ann (snd v)) (e:DefinedFunction Ann (snd vv)) : (DefinedFunction Ann (snd vv)) :=
- match snd v == snd vv with
- | left pf => eq_rect _ (fun t => DefinedFunction Ann t) e' _ pf
- | right pf => e
- end.
-
- Fixpoint df_subst {T Ann} (df: DefinedFunction Ann T) (v:var_type) (e':DefinedFunction UnitAnn (snd v)) :=
- match df with
- | Number _ x => Number tt x
- | Constant t _ x => Constant tt x
- | DVector n _ df => DVector tt (fun x => df_subst (df x) v e')
- | DMatrix n m _ df => DMatrix tt (fun i j => df_subst (df i j) v e')
- | Var vvar _ => substvar v vvar e' (Var vvar tt)
- | Plus _ l r => Plus tt (df_subst l v e') (df_subst r v e')
- | Times _ l r => Times tt (df_subst l v e') (df_subst r v e')
- | Minus _ l r => Minus tt (df_subst l v e') (df_subst r v e')
- | Divide _ l r => Divide tt (df_subst l v e') (df_subst r v e')
- | Square _ e => Square tt (df_subst e v e')
- | Exp _ e => Exp tt (df_subst e v e')
- | Log _ e => Log tt (df_subst e v e')
- | Abs _ e => Abs tt (df_subst e v e')
- | Sign _ e => Sign tt (df_subst e v e')
- | PSign _ e => PSign tt (df_subst e v e')
- | Max _ l r => Max tt (df_subst l v e') (df_subst r v e')
- | VectorElem n _ l i => VectorElem tt (df_subst l v e') i
- | MatrixElem m n _ l i j => MatrixElem tt (df_subst l v e') i j
- | VectorDot n _ l r =>
- VectorDot tt (df_subst l v e') (df_subst r v e')
- | VectorSum n _ e =>
- VectorSum tt (df_subst e v e')
- | MatrixSum n m _ e =>
- MatrixSum tt (df_subst e v e')
- | VectorScalMult n _ x r =>
- VectorScalMult tt (df_subst x v e') (df_subst r v e')
- | MatrixScalMult n m _ x r =>
- MatrixScalMult tt (df_subst x v e') (df_subst r v e')
- | MatrixVectorMult n m _ l r =>
- MatrixVectorMult tt (df_subst l v e') (df_subst r v e')
- | MatrixVectorAdd n m _ l r =>
- MatrixVectorAdd tt (df_subst l v e') (df_subst r v e')
- | MatrixMult n m p _ l r =>
- MatrixMult tt (df_subst l v e') (df_subst r v e')
- | VectorPlus n _ l r =>
- VectorPlus tt (df_subst l v e') (df_subst r v e')
- | VectorMinus n _ l r =>
- VectorMinus tt (df_subst l v e') (df_subst r v e')
- | MatrixPlus n m _ l r =>
- MatrixPlus tt (df_subst l v e') (df_subst r v e')
- | MatrixMinus n m _ l r =>
- MatrixMinus tt (df_subst l v e') (df_subst r v e')
- | VectorApply n _ x s l =>
- VectorApply tt x s (df_subst l v e')
- | MatrixApply n m _ x s l =>
- MatrixApply tt x s (df_subst l v e')
- | VLossfun n _ v1 v2 s l r =>
- VLossfun tt v1 v2 s (df_subst l v e') r
- | MLossfun n m _ v1 v2 s l r =>
- MLossfun tt v1 v2 s (df_subst l v e') r
- end.
-
- Definition df_substp {T Ann} :=
- fun e (ve':{v:var_type & DefinedFunction UnitAnn (snd v)}) =>
- @df_subst T Ann e (projT1 ve') (projT2 ve').
-
- Definition df_subst_list {T} (e:DefinedFunction UnitAnn T)
- (l:list {v:var_type & DefinedFunction UnitAnn (snd v)}) : DefinedFunction UnitAnn T
- := fold_left (@df_substp T UnitAnn) l e.
-
- End subst.
-
-
- (* restrict to scalar v? *)
-
- Fixpoint df_deriv {T} (df:DefinedFunction UnitAnn T) (v:var_type) {struct df} : DefinedFunction UnitAnn T
- := (match df with
- | Number _ _ => Number tt 0
- | Constant t _ x => Constant tt
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ df => DVector tt (fun x => df_deriv (df x) v)
- | DMatrix n m _ df => DMatrix tt (fun i j => df_deriv (df i j) v)
- | Var x _ => Constant tt
- match snd x as y return definition_function_types_interp y with
- | DTfloat => if x == v then 1 else 0
- | DTVector n => ConstVector n (if x == v then 1 else 0)
- | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0)
- end
- | Plus _ l r => Plus tt (df_deriv l v) (df_deriv r v)
- | Minus _ l r => Minus tt (df_deriv l v) (df_deriv r v)
- | Times _ l r => Plus tt (Times tt l (df_deriv r v))
- (Times tt (df_deriv l v) r)
- | Divide _ l r => Minus tt
- (Divide tt (df_deriv l v) r)
- (Divide tt (Times tt l (df_deriv r v))
- (Times tt r r))
- | Square _ e => Times tt
- (Times tt (Number tt 2) e) (df_deriv e v)
- | Exp _ e => Times tt (df_deriv e v) (Exp tt e)
- | Log _ e => Divide tt (df_deriv e v) e
- | Abs _ e => Times tt (df_deriv e v) (Sign tt e)
- | Sign _ e => Number tt 0
- | PSign _ e => Number tt 0
- | Max _ l r => Divide tt
- (Plus tt
- (Times tt (Minus tt
- (df_deriv r v)
- (df_deriv l v))
- (PSign tt (Minus tt r l)))
- (Plus tt (df_deriv r v) (df_deriv l v)))
- (Number tt 2)
- | VectorElem n _ l i => VectorElem tt (df_deriv l v) i
- | MatrixElem m n _ l i j => MatrixElem tt (df_deriv l v) i j
- | VectorDot n _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- Plus tt (VectorDot tt ll r) (VectorDot tt l rr)
- | VectorSum n _ l =>
- let ll := df_deriv l v in
- VectorSum tt ll
- | MatrixSum m n _ l =>
- let ll := df_deriv l v in
- MatrixSum tt ll
- | VectorScalMult n _ x r =>
- let xx := df_deriv x v in
- let rr := df_deriv r v in
- VectorPlus tt
- (VectorScalMult tt xx r)
- (VectorScalMult tt x rr)
- | MatrixScalMult n m _ x r =>
- let xx := df_deriv x v in
- let rr := df_deriv r v in
- MatrixPlus tt
- (MatrixScalMult tt xx r) (MatrixScalMult tt x rr)
- | MatrixVectorMult n m _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- VectorPlus tt (MatrixVectorMult tt ll r)
- (MatrixVectorMult tt l rr)
- | MatrixVectorAdd n m _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- MatrixVectorAdd tt ll rr
- | MatrixMult n m p _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- MatrixPlus tt (MatrixMult tt ll r) (MatrixMult tt l rr)
- | VectorPlus n _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- VectorPlus tt ll rr
- | VectorMinus n _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- VectorMinus tt ll rr
- | MatrixPlus n m _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- MatrixPlus tt ll rr
- | MatrixMinus n m _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- MatrixMinus tt ll rr
- | VectorApply n _ x s r =>
- let rr := df_deriv r v in
- let ss := df_deriv s (x, DTfloat) in
- DVector tt (fun i => Times tt (VectorElem tt rr i) (df_subst ss (x, DTfloat) (VectorElem tt r i)))
- | MatrixApply n m _ x s r =>
- let rr := df_deriv r v in
- let ss := df_deriv s (x, DTfloat) in
- DMatrix tt (fun i j => Times tt (MatrixElem tt rr i j) (df_subst ss (x, DTfloat) (MatrixElem tt r i j)))
- | VLossfun n _ v1 v2 s l r =>
- let ll := df_deriv l v in
- let ss := df_deriv s (v1, DTfloat) in
- VectorDot tt ll
- (DVector tt (fun i =>
- df_subst (df_subst ss (v1, DTfloat) (VectorElem tt l i))
- (v2, DTfloat) (Number tt (r i))))
- | MLossfun n m _ v1 v2 s l r =>
- let ll := df_deriv l v in
- let ss := df_deriv s (v1, DTfloat) in
- MatrixSum tt
- (DMatrix tt
- (fun i j =>
- (Divide tt
- (Times tt (MatrixElem tt ll i j)
- (df_subst (df_subst ss (v1, DTfloat) (MatrixElem tt l i j))
- (v2, DTfloat) (Number tt (r i j))))
- (Number tt (FfromZ (Z.of_nat m))))))
- end).
-
- Definition df_gradient {T} (df:DefinedFunction UnitAnn T) (lv:list var_type) : list (DefinedFunction UnitAnn T)
- := map (df_deriv df) lv.
-
- End deriv.
-
- Section eval.
-
- Program
- Fixpoint vartlookup (l:df_env) (a:var_type) :
- option (definition_function_types_interp (snd a))
- := match l with
- | nil => None
- | fv::os => if a == (projT1 fv) then
- Some (eq_rect _ definition_function_types_interp (projT2 fv) _ _)
- else vartlookup os a
- end.
-
- Fixpoint vart_update (l:df_env) (a:var_type) (n:definition_function_types_interp (snd a)) : df_env
- := match l with
- | nil => (mk_env_entry a n)::nil
- | fv::os => if a == (projT1 fv) then
- (mk_env_entry a n)::os else fv::(vart_update os a n)
- end.
-
- Fixpoint df_eval {T Ann} (σ:df_env) (df:DefinedFunction Ann T) : option (definition_function_types_interp T)
- := match df with
- | Number _ r => Some r
- | Constant t _ x => Some x
- | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval σ (dfs i))
- | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval σ (df i j))
- | Var x _ => vartlookup σ x
- | Plus _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (l' + r')
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (l' - r')
- | _, _ => None
- end
- | Times _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (l' * r')
- | _, _ => None
- end
- | Divide _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (l' / r')
- | _, _ => None
- end
- | Square _ e =>
- match df_eval σ e with
- | Some v => Some (v * v)
- | _ => None
- end
- | Exp _ e =>
- match df_eval σ e with
- | Some v => Some (Fexp v)
- | _ => None
- end
- | Log _ e =>
- match df_eval σ e with
- | Some v => Some (Fln v)
- | _ => None
- end
- | Abs _ e =>
- match df_eval σ e with
- | Some v => Some (Fabs v)
- | _ => None
- end
- | Sign _ e =>
- match df_eval σ e with
- | Some v => Some (sign v)
- | _ => None
- end
- | PSign _ e =>
- match df_eval σ e with
- | Some v => Some (pos_sign v)
- | _ => None
- end
- | Max _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (Fmax l' r')
- | _, _ => None
- end
- | VectorElem n _ l i =>
- match (df_eval σ l) with
- | Some l' => Some (l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval σ l) with
- | Some l' => Some (l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (vsum (fun i => (l' i) * (r' i)))
- | _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval σ l with
- | Some l' => Some (vsum l')
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval σ l with
- | Some l' => Some (msum l')
- | _ => None
- end
- | VectorScalMult n _ x r =>
- match df_eval σ x, df_eval σ r with
- | Some x', Some r' => Some (fun j => x' * (r' j))
- | _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- match df_eval σ x, df_eval σ r with
- | Some x', Some r' => Some (fun i j => x' * (r' i j))
- | _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (matrix_vector_mult l' r')
- | _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (matrix_vector_add l' r')
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (matrix_mult l' r')
- | _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (fun i => (l' i) + (r' i))
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (fun i => (l' i) - (r' i))
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j))
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j))
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- match df_eval σ r with
- | Some r' => vectoro_to_ovector
- (fun i =>
- let xv := (x, DTfloat):var_type in
- df_eval (cons (mk_env_entry xv (r' i)) nil) s)
- | _ => None
- end
- | MatrixApply n m _ x s r =>
- match df_eval σ r with
- | Some r' => matrixo_to_omatrix
- (fun i j =>
- let xv := (x, DTfloat):var_type in
- df_eval (cons (mk_env_entry xv (r' i j)) nil) s)
- | _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_eval σ l with
- | Some l' =>
- match (vectoro_to_ovector
- (fun i =>
- let xv1 := (v1,DTfloat):var_type in
- let xv2 := (v2,DTfloat):var_type in
- df_eval (cons (mk_env_entry xv1 (l' i))
- (cons (mk_env_entry xv2 (r i)) nil)) s)) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- match df_eval σ l with
- | Some l' =>
- match (matrixo_to_omatrix
- (fun i j =>
- let xv1 := (v1,DTfloat):var_type in
- let xv2 := (v2,DTfloat):var_type in
- df_eval (cons (mk_env_entry xv1 (l' i j))
- (cons (mk_env_entry xv2 (r i j)) nil)) s)) with
- | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _ => None
- end
-
- end.
-
- Fixpoint df_eval_tree {T Ann} (σ:df_env) (df:DefinedFunction Ann T) : option (DefinedFunction EvalAnn T)
- := match df with
- | Number _ r => Some (Number r r)
- | Constant t _ x => Some (Constant x x)
- | DVector n _ dfs =>
- match vectoro_to_ovector (fun i => df_eval_tree σ (dfs i)) with
- | Some val => Some (DVector (vmap get_annotation val) val)
- | _ => None
- end
- | DMatrix n m _ df =>
- match matrixo_to_omatrix (fun i j => df_eval_tree σ (df i j)) with
- | Some val => Some (DMatrix
- (vmap (fun x => vmap get_annotation x) val) val)
- | _ => None
- end
- | Var x _ => match vartlookup σ x with
- | Some val => Some (Var x val)
- | _ => None
- end
- | Plus _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => Some (Plus ((get_annotation l') + (get_annotation r'))
- l' r')
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => Some (Minus ((get_annotation l') - (get_annotation r'))
- l' r')
- | _, _ => None
- end
- | Times _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => Some (Times ((get_annotation l') * (get_annotation r'))
- l' r')
- | _, _ => None
- end
- | Divide _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => Some (Divide ((get_annotation l') / (get_annotation r'))
- l' r')
- | _, _ => None
- end
- | Square _ e =>
- match df_eval_tree σ e with
- | Some vv => let v := get_annotation vv in Some (Square (v * v) vv)
- | _ => None
- end
- | Exp _ e =>
- match df_eval_tree σ e with
- | Some vv => Some (Exp (Fexp (get_annotation vv)) vv)
- | _ => None
- end
- | Log _ e =>
- match df_eval_tree σ e with
- | Some vv => Some (Log (Fln (get_annotation vv)) vv)
- | _ => None
- end
- | Abs _ e =>
- match df_eval_tree σ e with
- | Some vv => Some (Abs (Fabs (get_annotation vv)) vv)
- | _ => None
- end
- | Sign _ e =>
- match df_eval_tree σ e with
- | Some vv => Some (Sign (sign (get_annotation vv)) vv)
- | _ => None
- end
- | PSign _ e =>
- match df_eval_tree σ e with
- | Some vv => Some (PSign (pos_sign (get_annotation vv)) vv)
- | _ => None
- end
- | Max _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => Some (Max (Fmax (get_annotation l') (get_annotation r'))
- l' r')
- | _, _ => None
- end
- | VectorElem n _ l i =>
- match (df_eval_tree σ l) with
- | Some l' => let vl' := get_annotation l' in
- Some (VectorElem (vl' i) l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_tree σ l) with
- | Some l' => let vl' := get_annotation l' in
- Some (MatrixElem (vl' i j) l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- Some (VectorDot (vsum (fun i => (vl' i) * (vr' i))) l' r')
- | _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_tree σ l with
- | Some l' => let vl' := get_annotation l' in
- Some (VectorSum (vsum vl') l')
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_tree σ l with
- | Some l' => let vl' := get_annotation l' in
- Some (MatrixSum (msum vl') l')
- | _ => None
- end
- | VectorScalMult n _ x r =>
- match df_eval_tree σ x, df_eval_tree σ r with
- | Some x', Some r' => let vx' := get_annotation x' in
- let vr' := get_annotation r' in
- let vec : Vector float n := (fun j => vx' * (vr' j)) in
- Some (VectorScalMult vec x' r')
- | _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- match df_eval_tree σ x, df_eval_tree σ r with
- | Some x', Some r' => let vx' := get_annotation x' in
- let vr' := get_annotation r' in
- let mat : Matrix float n m := fun i j => vx' * (vr' i j) in
- Some (MatrixScalMult mat x' r')
- | _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- Some (MatrixVectorMult (matrix_vector_mult vl' vr') l' r')
- | _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- Some (MatrixVectorAdd (matrix_vector_add vl' vr') l' r')
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- Some (MatrixMult (matrix_mult vl' vr') l' r')
- | _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- let vec : Vector float n :=
- fun i => (vl' i) + (vr' i) in
- Some (VectorPlus vec l' r')
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- let vec : Vector float n :=
- fun i => (vl' i) - (vr' i) in
- Some (VectorMinus vec l' r')
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- let mat : Matrix float n m :=
- fun i j => (vl' i j) + (vr' i j) in
- Some (MatrixPlus mat l' r')
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- let mat : Matrix float n m :=
- fun i j => (vl' i j) - (vr' i j) in
- Some (MatrixMinus mat l' r')
- | _, _ => None
- end
-
- | VectorApply n _ x s r =>
- match df_eval_tree σ r with
- | Some r' =>
- let vr' := get_annotation r' in
- match vectoro_to_ovector
- (fun i =>
- let xv := (x, DTfloat):var_type in
- df_eval (cons (mk_env_entry xv (vr' i)) nil) s) with
- | Some val => Some (VectorApply val x s r')
- | _ => None
- end
- | _ => None
- end
- | MatrixApply n m _ x s r =>
- match df_eval_tree σ r with
- | Some r' =>
- let vr' := get_annotation r' in
- match matrixo_to_omatrix
- (fun i j =>
- let xv := (x, DTfloat):var_type in
- df_eval (cons (mk_env_entry xv (vr' i j)) nil) s) with
- | Some val => Some (MatrixApply val x s r')
- | _ => None
- end
- | _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_eval_tree σ l with
- | Some l' =>
- let vl' := get_annotation l' in
- match (vectoro_to_ovector
- (fun i =>
- let xv1 := (v1,DTfloat):var_type in
- let xv2 := (v2,DTfloat):var_type in
- df_eval (cons (mk_env_entry xv1 (vl' i))
- (cons (mk_env_entry xv2 (r i)) nil)) s)) with
- | Some vv => Some (VLossfun (vsum vv) v1 v2 s l' r)
- | _ => None
- end
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- match df_eval_tree σ l with
- | Some l' =>
- let vl' := get_annotation l' in
- match (matrixo_to_omatrix
- (fun i j =>
- let xv1 := (v1,DTfloat):var_type in
- let xv2 := (v2,DTfloat):var_type in
- df_eval (cons (mk_env_entry xv1 (vl' i j))
- (cons (mk_env_entry xv2 (r i j)) nil)) s)) with
- | Some vv => Some (MLossfun ((msum vv)/(FfromZ (Z.of_nat m))) v1 v2 s l' r)
- | _ => None
- end
- | _ => None
- end
- end.
-
- Definition eval_env_entry_type := {T:definition_function_types & (DefinedFunction UnitAnn T) & definition_function_types_interp T}.
- Definition df_eval_env := list eval_env_entry_type.
-
- Definition mk_eval_env_entry {T} df val : eval_env_entry_type
- := let P := fun t => DefinedFunction UnitAnn t in
- let Q := fun t => definition_function_types_interp t in
- existT2 P Q T df val.
-
- Definition pair_update_evals {T} (df:DefinedFunction UnitAnn T) (val:definition_function_types_interp T) (dfevals : df_eval_env) : (definition_function_types_interp T * df_eval_env) :=
- (val, (mk_eval_env_entry df val)::dfevals).
-
- Fixpoint df_evals_list {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (dfevals : df_eval_env) : option (definition_function_types_interp T * df_eval_env)
- := match df with
- | Number _ r => Some (pair_update_evals (Number tt r) r dfevals)
- | Constant t _ x => Some (pair_update_evals (Constant tt x) x dfevals)
- | DVector n _ dfs => None (*vectoro_to_ovector (fun i => df_eval σ (dfs i))*)
- | DMatrix n m _ df => None (*matrixo_to_omatrix (fun i j => df_eval σ (df i j))*)
- | Var x _ =>
- match vartlookup σ x with
- | Some val => Some (pair_update_evals (Var x tt) val dfevals)
- | _ => None
- end
- | Plus _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') => Some (pair_update_evals (Plus tt l r) (l'+r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | Minus _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') => Some (pair_update_evals (Minus tt l r) (l'-r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | Times _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') => Some (pair_update_evals (Times tt l r) (l'*r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | Divide _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') => Some (pair_update_evals (Divide tt l r) (l'/r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | Square _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (Square tt e) (v * v) dfevals')
- | _ => None
- end
- | Exp _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (Exp tt e) (Fexp v) dfevals')
- | _ => None
- end
- | Log _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (Log tt e) (Fln v) dfevals')
- | _ => None
- end
- | Abs _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (Abs tt e) (Fabs v) dfevals')
- | _ => None
- end
- | Sign _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (Sign tt e) (sign v) dfevals')
- | _ => None
- end
- | PSign _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (PSign tt e) (pos_sign v) dfevals')
- | _ => None
- end
- | Max _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') => Some (pair_update_evals (Max tt l r) (Fmax l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | VectorElem n _ l i =>
- match (df_evals_list σ l dfevals) with
- | Some (l', dfevals') => Some (pair_update_evals (VectorElem tt l i) (l' i) dfevals')
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_evals_list σ l dfevals) with
- | Some (l', dfevals') => Some (pair_update_evals (MatrixElem tt l i j) (l' i j) dfevals')
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (VectorDot tt l r)
- (vsum (fun i => (l' i) * (r' i))) dfevals'')
- | _ => None
- end
- | _ => None
- end
- | VectorSum n _ l =>
- match df_evals_list σ l dfevals with
- | Some (l',dfevals') => Some (pair_update_evals (VectorSum tt l) (vsum l') dfevals')
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_evals_list σ l dfevals with
- | Some (l',dfevals') => Some (pair_update_evals (MatrixSum tt l) (msum l') dfevals')
- | _ => None
- end
- | VectorScalMult n _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (VectorScalMult tt l r) (fun j => l' * (r' j)) dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixScalMult n m _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixScalMult tt l r) (fun i j => l' * (r' i j)) dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixVectorMult tt l r)
- (matrix_vector_mult l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixVectorAdd tt l r)
- (matrix_vector_add l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixMult tt l r) (matrix_mult l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | VectorPlus n _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (VectorPlus tt l r)
- (fun i => (l' i) + (r' i)) dfevals'')
- | _ => None
- end
- | _ => None
- end
- | VectorMinus n _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (VectorMinus tt l r)
- (fun i => (l' i) - (r' i)) dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixPlus tt l r)
- (fun i j => (l' i j) + (r' i j)) dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixMinus tt l r)
- (fun i j => (l' i j) - (r' i j)) dfevals'')
- | _ => None
- end
- | _ => None
- end
- | VectorApply n _ x s r =>
- match df_evals_list σ r dfevals with
-(* | Some r' => vectoro_to_ovector
- (fun i =>
- let xv := (x, DTfloat):var_type in
- df_eval (cons (mk_env_entry xv (r' i)) σ) s) *)
- | _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_evals_list σ l dfevals with
-(* | Some l' =>
- match (vectoro_to_ovector
- (fun i =>
- let xv1 := (v1,DTfloat):var_type in
- let xv2 := (v2,DTfloat):var_type in
- df_eval (cons (mk_env_entry xv1 (l' i))
- (cons (mk_env_entry xv2 (r i)) σ)) s)) with
- | Some vv => Some (vsum vv)
- | _ => None
- end *)
- | _ => None
- end
- | _ => None
- end.
-
-(*
- Program
- Fixpoint evalslookup {T} (l:df_eval_env) (df:DefinedFunction UnitAnn T) :
- option (definition_function_types_interp T)
- := match l with
- | nil => None
- | fv::os => if T == (projT1 (sigT_of_sigT2 fv)) then
- if df == (projT2 (sigT_of_sigT2 fv)) then
- Some (eq_rect _ definition_function_types_interp (projT3 fv) _ _)
- else evalslookup os df
- else evalslookup os df
- end.
-*)
- Definition df_eval_symbolic_gradient {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (lv:list var_type) : option (list (definition_function_types_interp T))
- := listo_to_olist (map (df_eval σ) (df_gradient df lv)).
-
- End eval.
-
- Section isderiv.
-
- Context (σ:df_env).
- Context (v:SubVar).
-(*
- Inductive is_deriv : DefinedFunction -> float -> Prop
- :=
- | is_deriv_Number (x : float) : is_deriv (Number x) 0
- | is_deriv_Var_eq : is_deriv (Var v) 1
- | is_deriv_Var_neq (sv : SubVar) : sv <> v -> is_deriv (Var sv) 0
- | is_deriv_Plus l l' r r' :
- is_deriv l l' ->
- is_deriv r r' ->
- is_deriv (Plus l r) (l' + r')
- | is_deriv_Minus l l' r r' :
- is_deriv l l' ->
- is_deriv r r' ->
- is_deriv (Minus l r) (l' - r')
- | is_deriv_Times l le l' r re r' :
- df_eval σ l = Some le ->
- is_deriv l l' ->
- df_eval σ r = Some re ->
- is_deriv r r' ->
- is_deriv (Times l r) ((le * r') + (l' * re))
- | is_deriv_Divide l le l' r re r' :
- df_eval σ l = Some le ->
- is_deriv l l' ->
- df_eval σ r = Some re ->
- is_deriv r r' ->
- is_deriv (Times l r)
- (((l' * re ) - (le * r'))
- / (re * re))
- | is_deriv_Exp e ee e' :
- df_eval σ e = Some ee ->
- is_deriv e e' ->
- is_deriv (Exp e) (e' * (Fexp ee))
- | is_deriv_Log e ee e' :
- df_eval σ e = Some ee ->
- is_deriv e e' ->
- is_deriv (Exp e) (e' / ee)
- | is_deriv_Abs e ee e' :
- df_eval σ e = Some ee ->
- is_deriv e e' -> is_deriv (Abs e) (e' * (sign ee))
- | is_deriv_Sign (e : DefinedFunction) :
- is_deriv (Sign e) 0
- | is_deriv_PSign (e : DefinedFunction) :
- is_deriv (PSign e) 0
- | is_deriv_Max_l l le l' re r :
- df_eval σ l = Some le ->
- df_eval σ r = Some re ->
- (le > re) = true ->
- is_deriv l l' ->
- is_deriv (Max l r) l'
- | is_deriv_Max_r l le r re r' :
- df_eval σ l = Some le ->
- df_eval σ r = Some re ->
- (re >= le) = true ->
- is_deriv r r' ->
- is_deriv (Max l r) r'.
- (*
- | is_deriv_Max_eq l l' ee r r' :
- df_eval σ l = Some ee ->
- df_eval σ r = Some ee ->
- is_deriv l l' ->
- is_deriv r r' ->
- is_deriv (Max l r) ((l' + r')/2) *)
-
-*)
- End isderiv.
-
- Section deriv2.
-
- Fixpoint df_eval_deriv {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v:var_type) : option (definition_function_types_interp T)
- := (match df with
- | Number _ _ => Some 0
- | Constant t _ x => Some
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval_deriv σ (dfs i) v)
- | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval_deriv σ (df i j) v)
- | Var x _ => Some (let t:=snd x in
- match t return definition_function_types_interp t with
- | DTfloat => if x == v then 1 else 0
- | DTVector n => ConstVector n (if x == v then 1 else 0)
- | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0)
- end)
- | Plus _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some le, Some lr => Some (le + lr)
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some le, Some lr => Some (le - lr)
- | _, _ => None
- end
- | Times _ l r =>
- match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (le * rd +
- (ld * re))
- | _, _, _, _ => None
- end
- | Divide _ l r =>
- match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some ((ld / re) - ((le * rd) / (re * re)))
- | _, _, _, _ => None
- end
- | Square _ e =>
- match df_eval σ e, df_eval_deriv σ e v with
- | Some ee, Some ed => Some (2 * ee * ed)
- | _, _ => None
- end
- | Exp _ e =>
- match df_eval σ e, df_eval_deriv σ e v with
- | Some ee, Some ed => Some (ed * Fexp ee)
- | _, _ => None
- end
- | Log _ e =>
- match df_eval σ e, df_eval_deriv σ e v with
- | Some ee, Some ed => Some (ed / ee)
- | _, _ => None
- end
- | Abs _ e =>
- match df_eval σ e, df_eval_deriv σ e v with
- | Some ee, Some ed => Some (ed * (sign ee))
- | _, _ => None
- end
- | Sign _ e =>
- match df_eval_deriv σ e v with
- | Some _ => Some 0
- | None => None
- end
- | PSign _ e =>
- match df_eval_deriv σ e v with
- | Some _ => Some 0
- | None => None
- end
- | Max _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- if le <= re then df_eval_deriv σ r v else df_eval_deriv σ l v
- | _, _ => None
- end
- | VectorElem n _ l i =>
- match (df_eval_deriv σ l v) with
- | Some l' => Some (l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_deriv σ l v) with
- | Some l' => Some (l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (vsum (fun j => (le j) * (rd j) + (ld j) * (re j)))
- | _, _, _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_deriv σ l v with
- | Some ld =>
- Some (vsum ld)
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_deriv σ l v with
- | Some ld =>
- Some (msum ld)
- | _ => None
- end
- | VectorScalMult n _ x r =>
- match df_eval σ x, df_eval_deriv σ x v, df_eval σ r, df_eval_deriv σ r v with
- | Some xe, Some xd, Some re, Some rd => Some (fun j => xe * (rd j) + xd * (re j))
- | _, _, _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- match df_eval σ x, df_eval_deriv σ x v, df_eval σ r, df_eval_deriv σ r v with
- | Some xe, Some xd, Some re, Some rd => Some (fun i j => xe * (rd i j) + xd * (re i j))
- | _, _, _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (fun i => vsum (fun j => (le i j)*(rd j) + (ld i j)*(re j)))
- | _, _, _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some le, Some re =>
- Some (fun i j => (le i j) + (re i))
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (fun i k => vsum (fun j => (le i j)*(rd j k) + (ld i j)*(re j k)))
- | _, _, _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some l', Some r' => Some (fun i => (l' i) + (r' i))
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some l', Some r' => Some (fun i => (l' i) - (r' i))
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j))
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j))
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- match df_eval σ r, df_eval_deriv σ r v with
- | Some re, Some rd =>
- vectoro_to_ovector
- (fun i =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv (re i)) nil) s xv with
- | Some sd => Some ((rd i) * sd)
- | _ => None
- end)
- | _, _ => None
- end
- | MatrixApply n m _ x s r =>
- match df_eval σ r, df_eval_deriv σ r v with
- | Some re, Some rd =>
- matrixo_to_omatrix
- (fun i j =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv (re i j)) nil) s xv with
- | Some sd => Some ((rd i j) * sd)
- | _ => None
- end)
- | _, _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_eval σ l, df_eval_deriv σ l v with
- | Some le, Some ld =>
- match (vectoro_to_ovector
- (fun i =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv1 (le i))
- (cons (mk_env_entry xv2 (r i)) nil)) s xv1 with
- | Some sd => Some ((ld i) * sd)
- | _ => None
- end)) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _, _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- match df_eval σ l, df_eval_deriv σ l v with
- | Some le, Some ld =>
- match (matrixo_to_omatrix
- (fun i j =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv1 (le i j))
- (cons (mk_env_entry xv2 (r i j)) nil)) s xv1 with
- | Some sd => Some ((ld i j) * sd)
- | _ => None
- end)) with
- | Some vv => Some ((msum vv)/(FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _, _ => None
- end
- end).
-
- Definition mk_genvar_env (s:SubVar) := mk_env_entry (s, DTfloat) (FfromZ (Z.of_nat 1)) :: nil.
-
- (* the v environment below pairs variables with their derivatives *)
- (* in some sense this is giving a directional derivative defined by v *)
- Fixpoint df_eval_deriv_genvar {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v:df_env) : option (definition_function_types_interp T)
- := (match df with
- | Number _ _ => Some 0
- | Constant t _ x => Some
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval_deriv_genvar σ (dfs i) v)
- | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval_deriv_genvar σ (df i j) v)
- | Var x _ => Some (
- match vartlookup v x with
- | Some val => val
- | _ =>
- match (snd x) with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- end)
- | Plus _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some le, Some lr => Some (le + lr)
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some le, Some lr => Some (le - lr)
- | _, _ => None
- end
- | Times _ l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (le * rd +
- (ld * re))
- | _, _, _, _ => None
- end
- | Divide _ l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some ((ld / re) - ((le * rd) / (re * re)))
- | _, _, _, _ => None
- end
- | Square _ e =>
- match df_eval σ e, df_eval_deriv_genvar σ e v with
- | Some ee, Some ed => Some (2 * ee * ed)
- | _, _ => None
- end
- | Exp _ e =>
- match df_eval σ e, df_eval_deriv_genvar σ e v with
- | Some ee, Some ed => Some (ed * Fexp ee)
- | _, _ => None
- end
- | Log _ e =>
- match df_eval σ e, df_eval_deriv_genvar σ e v with
- | Some ee, Some ed => Some (ed / ee)
- | _, _ => None
- end
- | Abs _ e =>
- match df_eval σ e, df_eval_deriv_genvar σ e v with
- | Some ee, Some ed => Some (ed * (sign ee))
- | _, _ => None
- end
- | Sign _ e =>
- match df_eval_deriv_genvar σ e v with
- | Some _ => Some 0
- | None => None
- end
- | PSign _ e =>
- match df_eval_deriv_genvar σ e v with
- | Some _ => Some 0
- | None => None
- end
- | Max _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- if le <= re then df_eval_deriv_genvar σ r v else df_eval_deriv_genvar σ l v
- | _, _ => None
- end
- | VectorElem n _ l i =>
- match (df_eval_deriv_genvar σ l v) with
- | Some l' => Some (l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_deriv_genvar σ l v) with
- | Some l' => Some (l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (vsum (fun j => (le j) * (rd j) + (ld j) * (re j)))
- | _, _, _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_deriv_genvar σ l v with
- | Some ld =>
- Some (vsum ld)
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_deriv_genvar σ l v with
- | Some ld =>
- Some (msum ld)
- | _ => None
- end
- | VectorScalMult n _ x r =>
- match df_eval σ x, df_eval_deriv_genvar σ x v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some xe, Some xd, Some re, Some rd => Some (fun j => xe * (rd j) + xd * (re j))
- | _, _, _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- match df_eval σ x, df_eval_deriv_genvar σ x v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some xe, Some xd, Some re, Some rd => Some (fun i j => xe * (rd i j) + xd * (re i j))
- | _, _, _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (fun i => vsum (fun j => (le i j)*(rd j) + (ld i j)*(re j)))
- | _, _, _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some le, Some re =>
- Some (fun i j => (le i j) + (re i))
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (fun i k => vsum (fun j => (le i j)*(rd j k) + (ld i j)*(re j k)))
- | _, _, _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some l', Some r' => Some (fun i => (l' i) + (r' i))
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some l', Some r' => Some (fun i => (l' i) - (r' i))
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j))
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j))
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- match df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some re, Some rd =>
- vectoro_to_ovector
- (fun i =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv_genvar (mk_env_entry xv (re i) :: nil) s
- (mk_genvar_env x) with
- | Some sd => Some ((rd i) * sd)
- | _ => None
- end)
- | _, _ => None
- end
- | MatrixApply n m _ x s r =>
- match df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some re, Some rd =>
- matrixo_to_omatrix
- (fun i j =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv_genvar (mk_env_entry xv (re i j) :: nil) s
- (mk_genvar_env x) with
- | Some sd => Some ((rd i j) * sd)
- | _ => None
- end)
- | _, _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v with
- | Some le, Some ld =>
- match (vectoro_to_ovector
- (fun i =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv_genvar ( mk_env_entry xv1 (le i) ::
- mk_env_entry xv2 (r i) :: nil) s
- (mk_genvar_env v1) with
- | Some sd => Some ((ld i) * sd)
- | _ => None
- end)) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _, _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v with
- | Some le, Some ld =>
- match (matrixo_to_omatrix
- (fun i j =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv_genvar ( mk_env_entry xv1 (le i j) ::
- mk_env_entry xv2 (r i j) :: nil) s
- ( mk_genvar_env v1) with
- | Some sd => Some ((ld i j) * sd)
- | _ => None
- end)) with
- | Some vv => Some ((msum vv)/(FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _, _ => None
- end
- end).
-
-
- Definition definition_function_types_interp_prod (vart dft:definition_function_types) : Type
- := match vart with
- | DTfloat => definition_function_types_interp dft
- | DTVector n => Vector (definition_function_types_interp dft) n
- | DTMatrix m n => Matrix (definition_function_types_interp dft) m n
- end.
-
-
- Definition UnitVector (n:nat) (j : {n':nat | (n' < n)%nat}) : Vector float n :=
- fun i => if (proj1_sig i) == (proj1_sig j) then 1 else 0.
-
- Definition UnitMatrix (n m: nat)
- (i : {n':nat | (n' < n)%nat})
- (j : {m':nat | (m' < m)%nat}) : Matrix float n m :=
- fun a b => if (proj1_sig a) == (proj1_sig i) then
- (if (proj1_sig b) == (proj1_sig j) then 1 else 0)
- else 0.
-
- Definition const_env (v : var_type) : df_env
- := match (snd v) with
- | DTfloat => ((mk_env_entry (fst v, DTfloat) 0)::nil)
- | DTVector n => ((mk_env_entry (fst v, DTVector n) (ConstVector n 0))::nil)
- | DTMatrix m n => ((mk_env_entry (fst v, DTMatrix m n) (ConstMatrix m n 0))::nil)
- end.
-
- Fixpoint df_eval_tree_deriv {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (v:var_type) : option (definition_function_types_interp T)
- := (match df with
- | Number _ _ => Some 0
- | Constant t _ x => Some
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval_tree_deriv σ (dfs i) v)
- | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval_tree_deriv σ (df i j) v)
- | Var x _ => Some (let t:=snd x in
- match t return definition_function_types_interp t with
- | DTfloat => if x == v then 1 else 0
- | DTVector n => ConstVector n (if x == v then 1 else 0)
- | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0)
- end)
- | Plus _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some le, Some lr => Some (le + lr)
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some le, Some lr => Some (le - lr)
- | _, _ => None
- end
- | Times _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some (le * rd +
- (ld * re))
- | _, _ => None
- end
- | Divide _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some ((ld / re) - ((le * rd) / (re * re)))
- | _, _ => None
- end
- | Square _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv σ e v with
- | Some ed => Some (2 * ee * ed)
- | _ => None
- end
- | Exp _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv σ e v with
- | Some ed => Some (ed * Fexp ee)
- | _ => None
- end
- | Log _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv σ e v with
- | Some ed => Some (ed / ee)
- | _ => None
- end
- | Abs _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv σ e v with
- | Some ed => Some (ed * (sign ee))
- | _ => None
- end
- | Sign _ e =>
- match df_eval_tree_deriv σ e v with
- | Some _ => Some 0
- | None => None
- end
- | PSign _ e =>
- match df_eval_tree_deriv σ e v with
- | Some _ => Some 0
- | None => None
- end
- | Max _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- if le <= re then df_eval_tree_deriv σ r v else df_eval_tree_deriv σ l v
- | VectorElem n _ l i =>
- match (df_eval_tree_deriv σ l v) with
- | Some l' => Some (l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_tree_deriv σ l v) with
- | Some l' => Some (l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some (vsum (fun j => (le j) * (rd j) + (ld j) * (re j)))
- | _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_tree_deriv σ l v with
- | Some ld =>
- Some (vsum ld)
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_tree_deriv σ l v with
- | Some ld =>
- Some (msum ld)
- | _ => None
- end
- | VectorScalMult n _ x r =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_deriv σ x v, df_eval_tree_deriv σ r v with
- | Some xd, Some rd => Some (fun j => xe * (rd j) + xd * (re j))
- | _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_deriv σ x v, df_eval_tree_deriv σ r v with
- | Some xd, Some rd => Some (fun i j => xe * (rd i j) + xd * (re i j))
- | _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some (fun i => vsum (fun j => (le i j)*(rd j) + (ld i j)*(re j)))
- | _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some (fun i j => (ld i j) + (rd i))
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some (fun i k => vsum (fun j => (le i j)*(rd j k) + (ld i j)*(re j k)))
- | _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some l', Some r' => Some (fun i => (l' i) + (r' i))
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some l', Some r' => Some (fun i => (l' i) - (r' i))
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j))
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j))
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- let re := get_annotation r in
- match df_eval_tree_deriv σ r v with
- | Some rd =>
- vectoro_to_ovector
- (fun i =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv (re i)) nil ) s xv with
- | Some sd => Some ((rd i) * sd)
- | _ => None
- end)
- | _ => None
- end
- | MatrixApply n m _ x s r =>
- let re := get_annotation r in
- match df_eval_tree_deriv σ r v with
- | Some rd =>
- matrixo_to_omatrix
- (fun i j =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv (re i j)) nil ) s xv with
- | Some sd => Some ((rd i j) * sd)
- | _ => None
- end)
- | _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- let le := get_annotation l in
- match df_eval_tree_deriv σ l v with
- | Some ld =>
- match (vectoro_to_ovector
- (fun i =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv1 (le i))
- (cons (mk_env_entry xv2 (r i)) nil)) s xv1 with
- | Some sd => Some ((ld i) * sd)
- | _ => None
- end)) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- let le := get_annotation l in
- match df_eval_tree_deriv σ l v with
- | Some ld =>
- match (matrixo_to_omatrix
- (fun i j =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv1 (le i j))
- (cons (mk_env_entry xv2 (r i j)) nil )) s xv1 with
- | Some sd => Some ((ld i j) * sd)
- | _ => None
- end)) with
- | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _ => None
- end
- end).
-
- Fixpoint df_eval_tree_deriv_genvar {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (v:df_env) : option (definition_function_types_interp T)
- := (match df with
- | Number _ _ => Some 0
- | Constant t _ x => Some
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval_tree_deriv_genvar σ (dfs i) v)
- | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval_tree_deriv_genvar σ (df i j) v)
- | Var x _ => Some (
- match vartlookup v x with
- | Some val => val
- | _ =>
- match (snd x) with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- end)
- | Plus _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some le, Some lr => Some (le + lr)
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some le, Some lr => Some (le - lr)
- | _, _ => None
- end
- | Times _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (le * rd +
- (ld * re))
- | _, _ => None
- end
- | Divide _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some ((ld / re) - ((le * rd) / (re * re)))
- | _, _ => None
- end
- | Square _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv_genvar σ e v with
- | Some ed => Some (2 * ee * ed)
- | _ => None
- end
- | Exp _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv_genvar σ e v with
- | Some ed => Some (ed * Fexp ee)
- | _ => None
- end
- | Log _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv_genvar σ e v with
- | Some ed => Some (ed / ee)
- | _ => None
- end
- | Abs _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv_genvar σ e v with
- | Some ed => Some (ed * (sign ee))
- | _ => None
- end
- | Sign _ e =>
- match df_eval_tree_deriv_genvar σ e v with
- | Some _ => Some 0
- | None => None
- end
- | PSign _ e =>
- match df_eval_tree_deriv_genvar σ e v with
- | Some _ => Some 0
- | None => None
- end
- | Max _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- if le <= re then df_eval_tree_deriv_genvar σ r v else df_eval_tree_deriv_genvar σ l v
- | VectorElem n _ l i =>
- match (df_eval_tree_deriv_genvar σ l v) with
- | Some l' => Some (l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_tree_deriv_genvar σ l v) with
- | Some l' => Some (l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (vsum (fun j => (le j) * (rd j) + (ld j) * (re j)))
- | _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_tree_deriv_genvar σ l v with
- | Some ld =>
- Some (vsum ld)
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_tree_deriv_genvar σ l v with
- | Some ld =>
- Some (msum ld)
- | _ => None
- end
- | VectorScalMult n _ x r =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_deriv_genvar σ x v, df_eval_tree_deriv_genvar σ r v with
- | Some xd, Some rd => Some (fun j => xe * (rd j) + xd * (re j))
- | _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_deriv_genvar σ x v, df_eval_tree_deriv_genvar σ r v with
- | Some xd, Some rd => Some (fun i j => xe * (rd i j) + xd * (re i j))
- | _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (fun i => vsum (fun j => (le i j)*(rd j) + (ld i j)*(re j)))
- | _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (fun i j => (ld i j) + (rd i))
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (fun i k => vsum (fun j => (le i j)*(rd j k) + (ld i j)*(re j k)))
- | _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some l', Some r' => Some (fun i => (l' i) + (r' i))
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some l', Some r' => Some (fun i => (l' i) - (r' i))
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j))
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j))
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- let re := get_annotation r in
- match df_eval_tree_deriv_genvar σ r v with
- | Some rd =>
- vectoro_to_ovector
- (fun i =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv_genvar (cons (mk_env_entry xv (re i)) nil ) s
- (mk_genvar_env x) with
- | Some sd => Some ((rd i) * sd)
- | _ => None
- end)
- | _ => None
- end
- | MatrixApply n m _ x s r =>
- let re := get_annotation r in
- match df_eval_tree_deriv_genvar σ r v with
- | Some rd =>
- matrixo_to_omatrix
- (fun i j =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv_genvar (cons (mk_env_entry xv (re i j)) nil) s
- (mk_genvar_env x) with
- | Some sd => Some ((rd i j) * sd)
- | _ => None
- end)
- | _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- let le := get_annotation l in
- match df_eval_tree_deriv_genvar σ l v with
- | Some ld =>
- match (vectoro_to_ovector
- (fun i =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv_genvar (cons (mk_env_entry xv1 (le i))
- (cons (mk_env_entry xv2 (r i)) nil )) s
- (mk_genvar_env v1) with
- | Some sd => Some ((ld i) * sd)
- | _ => None
- end)) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- let le := get_annotation l in
- match df_eval_tree_deriv_genvar σ l v with
- | Some ld =>
- match (matrixo_to_omatrix
- (fun i j =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv_genvar (cons (mk_env_entry xv1 (le i j))
- (cons (mk_env_entry xv2 (r i j)) nil)) s
- (mk_genvar_env v1) with
- | Some sd => Some ((ld i j) * sd)
- | _ => None
- end)) with
- | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _ => None
- end
- end).
-
- Definition vector_env_iter {n} {A} (f: A -> df_env -> option df_env)
- (env: df_env) (v : Vector A n) : option df_env :=
- vector_fold_right (fun a oenv => match oenv with
- | Some env => f a env
- | _ => None
- end)
- (Some env) v.
-
- Fixpoint list_env_iter {A} (f: A -> df_env -> option df_env)
- (oenv:option df_env) (l: list A) : option df_env :=
- match oenv, l with
- | Some env, x :: l' => list_env_iter f (f x env) l'
- | _, _ => oenv
- end.
-
- Lemma list_env_iter_none {A} (f: A -> df_env -> option df_env) (l: list A) :
- list_env_iter f None l = None.
- Proof.
- induction l.
- now simpl.
- now simpl.
- Qed.
-
- Lemma list_env_iter_env_not_none {A} (f: A -> df_env -> option df_env)
- (oenv : option df_env) (l: list A):
- list_env_iter f oenv l <> None -> oenv <> None.
- Proof.
- intros.
- destruct oenv.
- + discriminate.
- + rewrite list_env_iter_none in H.
- tauto.
- Qed.
-
- Lemma list_env_iter_app {A} (f: A -> df_env -> option df_env)
- (oenv:option df_env) (l1 l2: list A) :
- list_env_iter f oenv (l1++l2) =
- list_env_iter f (list_env_iter f oenv l1) l2.
- Proof.
- revert l2 oenv.
- induction l1; intros l2 oenv; simpl.
- - now destruct oenv.
- - destruct oenv.
- + auto.
- + now rewrite list_env_iter_none.
- Qed.
-
-
- Lemma list_env_iter_ext {A} f1 f2 oenv (l:list A) :
- (forall x a, In x l -> f1 x a = f2 x a) ->
- list_env_iter f1 oenv l = list_env_iter f2 oenv l.
- Proof.
- intros fa.
- revert oenv.
- induction l; intros oenv; intros; simpl
- ; match_destr.
- rewrite fa; simpl; intuition.
- Qed.
-
-
- Definition two_vector_env_iter {n} {A B} (f: A -> B -> df_env -> option df_env)
- (env: df_env) (v: Vector A n) (w: Vector B n) : option df_env :=
- vector_env_iter (fun '(a,b) env => f a b env) env
- (vector_zip v w).
-
-
- Definition two_vector_env_iter_alt {n} {A B} (f: A -> B -> df_env -> option df_env)
- (env: df_env) (v: Vector A n) (w: Vector B n) : option df_env :=
- list_env_iter (fun i env => f (v i) (w i) env) (Some env) (bounded_seq0 n).
-
- Definition matrix_env_iter {m n} {A} (f: A -> df_env -> option df_env)
- (env: option df_env) (mat : Matrix A m n) : option df_env :=
- vector_fold_right
- (fun vec oenv =>
- vector_fold_right (fun a oenv => match oenv with
- | Some env => f a env
- | _ => None
- end) oenv vec
- ) env mat.
-
- Definition two_matrix_env_iter {n m} {A B} (f: A -> B -> df_env -> option df_env)
- (env: option df_env) (v: Matrix A n m) (w: Matrix B n m) : option df_env :=
- let vw := matrix_zip v w in
- matrix_env_iter (fun '(a,b) e => f a b e) env vw.
-
- Definition two_matrix_env_iter_alt {n m} {A B} (f: A -> B -> df_env -> option df_env)
- (env: df_env) (v: Matrix A n m) (w: Matrix B n m) : option df_env :=
- list_env_iter (fun i env => list_env_iter (fun j env => f (v i j) (w i j) env)
- (Some env) (bounded_seq0 m))
- (Some env) (bounded_seq0 n).
-
-
- Program Definition addvar (x : var_type) (grad_env:df_env) :=
- (match snd x as y return snd x = y ->
- definition_function_types_interp y ->
- definition_function_types_interp y with
- | DTfloat => fun pf grad => match vartlookup grad_env x with
- | Some val => grad + ((coerce _ val):float)
- | _ => grad
- end
- | DTVector n => fun pf grad => match vartlookup grad_env x with
- | Some val => fun i => (grad i) + (((coerce _ val):Vector float n) i)
- | _ => grad
- end
- | DTMatrix m n => fun pf grad => match vartlookup grad_env x with
- | Some val => fun i j => (((coerce _ val):Matrix float m n) i j) + (grad i j)
- | _ => grad
- end
- end) (eq_refl _).
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
-
- Definition gradenv_init1 (v : var_type) : env_entry_type :=
- mk_env_entry v
- (match snd v as y return definition_function_types_interp y with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix n m => ConstMatrix n m 0
- end).
-
- Definition gradenv_init (dvars : list var_type) : df_env :=
- map gradenv_init1 dvars.
-
- Fixpoint df_eval_backprop_deriv {T Ann} (σ:df_env) (df:DefinedFunction Ann T) (grad_env:df_env) {struct df} : definition_function_types_interp T -> option df_env
- := match df with
- | Number _ _ => fun grad => Some grad_env
- | Constant _ _ _ => fun grad => Some grad_env
- | DVector n _ dfs => fun grad =>
- two_vector_env_iter_alt (fun x g genv => df_eval_backprop_deriv σ x genv g)
- grad_env dfs grad
- | DMatrix n m _ dfs => fun grad =>
- two_matrix_env_iter_alt (fun x g genv => df_eval_backprop_deriv σ x genv g)
- grad_env dfs grad
- | Var x _ => fun grad =>
- if vartlookup grad_env x then
- Some (vart_update grad_env x (addvar x grad_env grad))
- else Some grad_env
- | Plus _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | Minus _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (-grad)
- | _ => None
- end
- | Times _ l r => fun grad =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- match df_eval_backprop_deriv σ l grad_env (re * grad) with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (le * grad)
- | _ => None
- end
- | _, _ => None
- end
- | Divide _ l r => fun grad =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- match df_eval_backprop_deriv σ l grad_env (grad / re) with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (- le / (re * re) * grad)
- | _ => None
- end
- | _, _ => None
- end
- | Square _ e => fun grad =>
- match df_eval σ e with
- | Some ee => df_eval_backprop_deriv σ e grad_env (2 * ee * grad)
- | _ => None
- end
- | Exp _ e => fun grad =>
- match df_eval σ e with
- | Some ee => df_eval_backprop_deriv σ e grad_env (grad * Fexp ee)
- | _ => None
- end
- | Log _ e => fun grad =>
- match df_eval σ e with
- | Some ee => df_eval_backprop_deriv σ e grad_env (grad / ee)
- | _ => None
- end
- | Abs _ e => fun grad =>
- match df_eval σ e with
- | Some ee => df_eval_backprop_deriv σ e grad_env (grad * (sign ee))
- | _ => None
- end
- | Sign _ e => fun grad => df_eval_backprop_deriv σ e grad_env 0
- | PSign _ e => fun grad => df_eval_backprop_deriv σ e grad_env 0
- | Max _ l r => fun grad =>
- match df_eval σ l,
- df_eval σ r with
- | Some le, Some re =>
- if le <= re then
- (df_eval_backprop_deriv σ r grad_env grad) else
- (df_eval_backprop_deriv σ l grad_env grad)
- | _, _ => None
- end
- | VectorElem n _ l i => fun grad =>
- let grad' := fun k => if proj1_sig k == proj1_sig i then grad else 0 in
- df_eval_backprop_deriv σ l grad_env grad'
- | MatrixElem m n _ l i j => fun grad =>
- let grad' := fun k1 k2 =>
- if (proj1_sig k1 == proj1_sig i) then
- if (proj1_sig k2 == proj1_sig j) then grad else 0
- else 0 in
- df_eval_backprop_deriv σ l grad_env grad'
- | VectorDot n _ l r => fun grad =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- match df_eval_backprop_deriv σ l grad_env (vmap (fun rv => rv*grad) re) with
- | Some grad_env' =>
- df_eval_backprop_deriv σ r grad_env' (vmap (fun lv => lv*grad) le)
- | _ => None
- end
- | _, _ => None
- end
- | VectorSum n _ l => fun grad =>
- df_eval_backprop_deriv σ l grad_env (ConstVector n grad)
- | MatrixSum n m _ l => fun grad =>
- df_eval_backprop_deriv σ l grad_env (ConstMatrix n m grad)
- | VectorScalMult n _ x r => fun grad =>
- match df_eval σ x, df_eval σ r with
- | Some xe, Some re =>
- match df_eval_backprop_deriv σ x grad_env (vsum (fun j => (re j) * (grad j))) with
- | Some grad_env' =>
- df_eval_backprop_deriv σ r grad_env' (fun j => xe * (grad j))
- | _ => None
- end
- | _, _ => None
- end
- | MatrixScalMult n m _ x r => fun grad =>
- match df_eval σ x, df_eval σ r with
- | Some xe, Some re =>
- match df_eval_backprop_deriv σ x grad_env (msum (fun i j => (re i j) * (grad i j))) with
- | Some grad_env' =>
- df_eval_backprop_deriv σ r grad_env' (fun i j => (grad i j) * xe)
- | _ => None
- end
- | _, _ => None
- end
- | MatrixVectorMult n m _ l r => fun grad =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- match df_eval_backprop_deriv σ l grad_env (fun i j => (grad i) * (re j)) with
-
-
- | Some grad_env' =>
- df_eval_backprop_deriv σ r grad_env' (matrix_vector_mult (fun i j => le j i) grad)
- | _ => None
- end
- | _, _ => None
- end
- | MatrixVectorAdd n m _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' =>
- match list_env_iter
- (fun i env => df_eval_backprop_deriv σ r env ((transpose grad) i))
- (Some grad_env') (bounded_seq0 m) with
- | Some grad_env'' => Some grad_env''
- | _ => None
- end
- | _ => None
- end
- | MatrixMult n m p _ l r => fun grad =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- match df_eval_backprop_deriv σ l grad_env (matrix_mult grad (fun i j => (re j i))) with
-
-
- | Some grad_env' =>
- df_eval_backprop_deriv σ r grad_env' (matrix_mult (fun i j => le j i) grad)
- | _ => None
- end
- | _, _ => None
- end
- | VectorPlus n _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | VectorMinus n _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (fun i => - (grad i))
- | _ => None
- end
- | MatrixPlus n m _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | MatrixMinus n m _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (fun i j => - (grad i j))
- | _ => None
- end
- | VectorApply n _ x s r => fun grad =>
- match df_eval σ r with
- | Some re =>
- let xv := (x, DTfloat):var_type in
- let s' := df_deriv s xv in
- let ograd :=
- vmap (fun '(rei, g) =>
-
- match df_eval (cons (mk_env_entry xv rei) nil) s' with
- | Some se => Some (g * se)
- | _ => None
- end)
- (vector_zip re grad) in
- match vectoro_to_ovector ograd with
- | Some grad' => df_eval_backprop_deriv σ r grad_env grad'
- | _ => None
- end
- | _ => None
- end
- | MatrixApply n m _ x s r => fun grad =>
- match df_eval σ r with
- | Some re =>
- let xv := (x, DTfloat):var_type in
- let s' := df_deriv s xv in
- let ograd :=
- mmap (fun '(rei, g) =>
- match df_eval (cons (mk_env_entry xv rei) nil) s' with
- | Some se => Some (g * se)
- | _ => None
- end)
- (matrix_zip re grad) in
- match matrixo_to_omatrix ograd with
- | Some grad' => df_eval_backprop_deriv σ r grad_env grad'
- | _ => None
- end
- | _ => None
- end
- | VLossfun n _ v1 v2 s l re => fun grad =>
- match df_eval σ l with
- | Some le =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- let s' := df_deriv s xv1 in
- let ograd :=
- vmap (fun '(lei, rei) =>
- let senv := cons (mk_env_entry xv1 lei)
- (cons (mk_env_entry xv2 rei) nil) in
- match df_eval senv s' with
- | Some se => Some (grad * se)
- | _ => None
- end)
- (vector_zip le re) in
- match vectoro_to_ovector ograd with
- | Some grad' => df_eval_backprop_deriv σ l grad_env grad'
- | _ => None
- end
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l re => fun grad =>
- match df_eval σ l with
- | Some le =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- let s' := df_deriv s xv1 in
- let ograd :=
- mmap (fun '(lei, rei) =>
- let senv := cons (mk_env_entry xv1 lei)
- (cons (mk_env_entry xv2 rei) nil) in
- match df_eval senv s' with
- | Some se => Some ((grad * se)/(FfromZ (Z.of_nat m)))
- | _ => None
- end)
- (matrix_zip le re) in
- match matrixo_to_omatrix ograd with
- | Some grad' => df_eval_backprop_deriv σ l grad_env grad'
- | _ => None
- end
- | _ => None
- end
- end.
-
- Definition lifted_type (B:Type) T
- := match T with
- | DTfloat => B
- | DTVector n => Vector B n
- | DTMatrix m n => Matrix B m n
- end.
-
- Fixpoint df_eval_tree_backprop_deriv {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (grad_env:df_env) {struct df} : definition_function_types_interp T -> option df_env
- := match df with
- | Number _ _ => fun grad => Some grad_env
- | Constant _ _ _ => fun grad => Some grad_env
- | DVector n _ dfs => fun grad =>
- two_vector_env_iter_alt (fun x g genv => df_eval_tree_backprop_deriv σ x genv g)
- grad_env dfs grad
- | DMatrix n m _ dfs => fun grad =>
- two_matrix_env_iter_alt (fun x g genv => df_eval_tree_backprop_deriv σ x genv g)
- grad_env dfs grad
- | Var x _ => fun grad =>
- if vartlookup grad_env x then
- Some (vart_update grad_env x (addvar x grad_env grad))
- else Some grad_env
- | Plus _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | Minus _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (-grad)
- | _ => None
- end
- | Times _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_backprop_deriv σ l grad_env (re * grad) with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (le * grad)
- | _ => None
- end
- | Divide _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_backprop_deriv σ l grad_env (grad / re) with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (- le / (re * re) * grad)
- | _ => None
- end
- | Square _ e => fun grad =>
- let ee := get_annotation e in
- df_eval_tree_backprop_deriv σ e grad_env (2 * ee * grad)
- | Exp _ e => fun grad =>
- let ee := get_annotation e in
- df_eval_tree_backprop_deriv σ e grad_env (grad * Fexp ee)
- | Log _ e => fun grad =>
- let ee := get_annotation e in
- df_eval_tree_backprop_deriv σ e grad_env (grad / ee)
- | Abs _ e => fun grad =>
- let ee := get_annotation e in
- df_eval_tree_backprop_deriv σ e grad_env (grad * (sign ee))
- | Sign _ e => fun grad => df_eval_tree_backprop_deriv σ e grad_env 0
- | PSign _ e => fun grad => df_eval_tree_backprop_deriv σ e grad_env 0
- | Max _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- if le <= re then
- (df_eval_tree_backprop_deriv σ r grad_env grad) else
- (df_eval_tree_backprop_deriv σ l grad_env grad)
- | VectorElem n _ l i => fun grad =>
- let grad' := fun k => if proj1_sig k == proj1_sig i then grad else 0 in
- df_eval_tree_backprop_deriv σ l grad_env grad'
- | MatrixElem m n _ l i j => fun grad =>
- let grad' := fun k1 k2 =>
- if (proj1_sig k1 == proj1_sig i) then
- if (proj1_sig k2 == proj1_sig j) then grad else 0
- else 0 in
- df_eval_tree_backprop_deriv σ l grad_env grad'
- | VectorDot n _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_backprop_deriv σ l grad_env (vmap (fun rv => rv*grad) re) with
- | Some grad_env' =>
- df_eval_tree_backprop_deriv σ r grad_env' (vmap (fun lv => lv*grad) le)
- | _ => None
- end
- | VectorSum n _ l => fun grad =>
- df_eval_tree_backprop_deriv σ l grad_env (ConstVector n grad)
- | MatrixSum n m _ l => fun grad =>
- df_eval_tree_backprop_deriv σ l grad_env (ConstMatrix n m grad)
- | VectorScalMult n _ x r => fun grad =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_backprop_deriv σ x grad_env (vsum (fun j => (re j) * (grad j))) with
- | Some grad_env' =>
- df_eval_tree_backprop_deriv σ r grad_env' (fun j => xe * (grad j))
- | _ => None
- end
- | MatrixScalMult n m _ x r => fun grad =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_backprop_deriv σ x grad_env (msum (fun i j => (re i j) * (grad i j))) with
- | Some grad_env' =>
- df_eval_tree_backprop_deriv σ r grad_env' (fun i j => (grad i j) * xe)
- | _ => None
- end
- | MatrixVectorMult n m _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_backprop_deriv σ l grad_env (fun i j => (grad i) * (re j)) with
- | Some grad_env' =>
- df_eval_tree_backprop_deriv σ r grad_env' (matrix_vector_mult (fun i j => le j i) grad)
- | _ => None
- end
- | MatrixVectorAdd n m _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' =>
- match list_env_iter
- (fun i env => df_eval_tree_backprop_deriv σ r env ((transpose grad) i))
- (Some grad_env') (bounded_seq0 m) with
- | Some grad_env'' => Some grad_env''
- | _ => None
- end
- | _ => None
- end
- | MatrixMult n m p _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_backprop_deriv σ l grad_env (matrix_mult grad (fun i j => (re j i))) with
- | Some grad_env' =>
- df_eval_tree_backprop_deriv σ r grad_env' (matrix_mult (fun i j => le j i) grad)
- | _ => None
- end
- | VectorPlus n _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | VectorMinus n _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (fun i => - (grad i))
- | _ => None
- end
- | MatrixPlus n m _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | MatrixMinus n m _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (fun i j => - (grad i j))
- | _ => None
- end
- | VectorApply n _ x s r => fun grad =>
- let re := get_annotation r in
- let xv := (x, DTfloat):var_type in
- let s' := df_deriv s xv in
- let ograd :=
- vmap (fun '(rei, g) =>
- match df_eval (cons (mk_env_entry xv rei) nil) s' with
- | Some se => Some (g * se)
- | _ => None
- end)
- (vector_zip re grad) in
- match vectoro_to_ovector ograd with
- | Some grad' => df_eval_tree_backprop_deriv σ r grad_env grad'
- | _ => None
- end
- | MatrixApply n m _ x s r => fun grad =>
- let re := get_annotation r in
- let xv := (x, DTfloat):var_type in
- let s' := df_deriv s xv in
- let ograd :=
- mmap (fun '(rei, g) =>
- match df_eval (cons (mk_env_entry xv rei) nil) s' with
- | Some se => Some (g * se)
- | _ => None
- end)
- (matrix_zip re grad) in
- match matrixo_to_omatrix ograd with
- | Some grad' => df_eval_tree_backprop_deriv σ r grad_env grad'
- | _ => None
- end
- | VLossfun n _ v1 v2 s l re => fun grad =>
- let le := get_annotation l in
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- let s' := df_deriv s xv1 in
- let ograd :=
- vmap (fun '(lei, rei) =>
- let senv := cons (mk_env_entry xv1 lei)
- (cons (mk_env_entry xv2 rei) nil) in
- match df_eval senv s' with
- | Some se => Some (grad * se)
- | _ => None
- end)
- (vector_zip le re) in
- match vectoro_to_ovector ograd with
- | Some grad' => df_eval_tree_backprop_deriv σ l grad_env grad'
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l re => fun grad =>
- let le := get_annotation l in
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- let s' := df_deriv s xv1 in
- let ograd :=
- mmap (fun '(lei, rei) =>
- let senv := cons (mk_env_entry xv1 lei)
- (cons (mk_env_entry xv2 rei) nil) in
- match df_eval senv s' with
- | Some se => Some ((grad * se) / (FfromZ (Z.of_nat m)))
- | _ => None
- end)
- (matrix_zip le re) in
- match matrixo_to_omatrix ograd with
- | Some grad' => df_eval_tree_backprop_deriv σ l grad_env grad'
- | _ => None
- end
- end.
-
- Definition o_df_env_to_df_env (oenv : option df_env) : df_env :=
- match oenv with
- | Some env => env
- | _ => nil
- end.
-
-
- Definition backprop_lookup (oenv:option df_env) (a:var_type) :
- option (definition_function_types_interp (snd a)) :=
- match oenv with
- | Some env =>
- match vartlookup env a with
- | Some val => Some val
- | _ => None
- end
- | _ => None
- end.
-
- Definition is_scalar_df_type (dft:definition_function_types) : Prop
- := match dft with
- | DTfloat => True
- | _ => False
- end.
-
- Fixpoint is_scalar_function {Ann} {T} (df:DefinedFunction Ann T) : Prop
- := match df with
- | Number _ _ => True
- | Constant t _ _ => is_scalar_df_type t
- | Var v _ => is_scalar_df_type (snd v)
- | Plus _ l r => is_scalar_function l /\ is_scalar_function r
- | Minus _ l r => is_scalar_function l /\ is_scalar_function r
- | Times _ l r => is_scalar_function l /\ is_scalar_function r
- | Divide _ l r => is_scalar_function l /\ is_scalar_function r
- | Square _ e => is_scalar_function e
- | Exp _ e => is_scalar_function e
- | Log _ e => is_scalar_function e
- | Abs _ e => is_scalar_function e
- | Sign _ e => is_scalar_function e
- | PSign _ e => is_scalar_function e
- | Max _ l r => is_scalar_function l /\ is_scalar_function r
- | _ => False
- end.
-
- Fixpoint has_scalar_functions {Ann} {T}
- (df:DefinedFunction Ann T) {struct df}: Prop
- := match df with
- | Number _ _ => True
- | Constant _ _ _ => True
- | DVector n _ vec =>
- vforall (has_scalar_functions) vec
- | DMatrix n m _ mat =>
- vforall (vforall (has_scalar_functions)) mat
- | Var _ _ => True
- | Plus _ l r => (has_scalar_functions l) /\ (has_scalar_functions r)
- | Minus _ l r => (has_scalar_functions l) /\ (has_scalar_functions r)
- | Times _ l r => (has_scalar_functions l) /\ (has_scalar_functions r)
- | Divide _ l r => (has_scalar_functions l) /\ (has_scalar_functions r)
- | Square _ l => has_scalar_functions l
- | Exp _ l => has_scalar_functions l
- | Log _ l => has_scalar_functions l
- | Abs _ l => has_scalar_functions l
- | Sign _ l => has_scalar_functions l
- | PSign _ l => has_scalar_functions l
- | Max _ l r => (has_scalar_functions l) /\ (has_scalar_functions r)
- | VectorDot n _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | VectorSum _ _ l => has_scalar_functions l
- | MatrixSum _ _ _ l => has_scalar_functions l
- | VectorElem _ _ vec i => has_scalar_functions vec
- | MatrixElem _ _ _ mat i j => has_scalar_functions mat
- | MatrixVectorMult _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | MatrixVectorAdd _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | MatrixMult _ _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | VectorPlus _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | VectorMinus _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | MatrixPlus _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | MatrixMinus _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | VectorScalMult _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | MatrixScalMult _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | VectorApply _ _ _ s l => is_scalar_function s /\ has_scalar_functions l
- | MatrixApply _ _ _ _ s l => is_scalar_function s /\ has_scalar_functions l
- | VLossfun _ _ _ _ s l _ => is_scalar_function s /\ has_scalar_functions l
- | MLossfun _ _ _ _ _ s l _ => is_scalar_function s /\ has_scalar_functions l
- end.
-
- Lemma is_scalar_function_has_scalar_functions {Ann} {T} (df:DefinedFunction Ann T) :
- is_scalar_function df -> has_scalar_functions df.
- Proof.
- induction df; firstorder.
- Qed.
-
- Hint Resolve is_scalar_function_has_scalar_functions : fml.
-
- Definition DefinedFunction_ind_unit_has_scalar_functions
- (P : forall (d : definition_function_types), DefinedFunction UnitAnn d -> Prop)
- (f : forall (ann : UnitAnn DTfloat) (x : float),
- P DTfloat (Number ann x))
- (f0 : forall (t : definition_function_types)
- (ann : UnitAnn t) (x : definition_function_types_interp t), P t (Constant ann x))
- (f1 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (x : Vector (DefinedFunction UnitAnn DTfloat) n),
- (forall s : {n' : nat | (n' < n)%nat}, P DTfloat (x s)) ->
- P (DTVector n) (DVector ann x))
- (f2 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (x : Matrix (DefinedFunction UnitAnn DTfloat) n m),
- (forall (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}),
- P DTfloat (x s s0)) -> P (DTMatrix n m) (DMatrix ann x))
- (f3 : forall (v : var_type) (ann : UnitAnn (snd v)),
- P (snd v) (Var v ann))
- (f4 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Plus ann l r))
- (f5 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Minus ann l r))
- (f6 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Times ann l r))
- (f7 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Divide ann l r))
- (f8 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Square ann e))
- (f9 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Exp ann e))
- (f10 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Log ann e))
- (f11 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Abs ann e))
- (f12 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Sign ann e))
- (f13 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (PSign ann e))
- (f14 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Max ann l r))
- (f15 : forall (n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P DTfloat (VectorDot ann l r))
- (f16 : forall (n : nat) (ann : UnitAnn DTfloat)
- (v : DefinedFunction UnitAnn (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v))
- (f17 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (v : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v))
- (f18 : forall (n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i))
- (f19 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}),
- P DTfloat (MatrixElem ann l i j))
- (f20 : forall (m n : nat) (ann : UnitAnn (DTVector m))
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r))
- (f21 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n))
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction UnitAnn (DTVector m),
- P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r))
- (f22 : forall (m p n : nat) (ann : UnitAnn (DTMatrix m n))
- (l : DefinedFunction UnitAnn (DTMatrix m p)),
- P (DTMatrix m p) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix p n),
- P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r))
- (f23 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r))
- (f24 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r))
- (f25 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (l : DefinedFunction UnitAnn (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r))
- (f26 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (l : DefinedFunction UnitAnn (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r))
- (f27 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (x : DefinedFunction UnitAnn DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l))
- (f28 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (x : DefinedFunction UnitAnn DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l))
- (f29 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- is_scalar_function s ->
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l))
- (f30 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- is_scalar_function s ->
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTMatrix m n),
- P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l))
- (f31 : forall (n : nat) (ann : UnitAnn DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- is_scalar_function s ->
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> forall r : Vector float n, P DTfloat (VLossfun ann v1 v2 s l r))
- (f32 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- is_scalar_function s ->
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTMatrix m n),
- P (DTMatrix m n) l ->
- forall r : Matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r))
- : forall (d : definition_function_types)
- (d0 : DefinedFunction UnitAnn d)
- (hs:has_scalar_functions d0), P d d0.
- Proof.
- refine (fix
- F (d : definition_function_types)
- (d0 : DefinedFunction UnitAnn d)
- {struct d0} : has_scalar_functions d0 -> P d d0 :=
- match d0 as d2 in (DefinedFunction _ d1) return has_scalar_functions d2 -> (P d1 d2) with
- | Number ann x => fun hs => f ann x
- | @Constant _ t ann x => fun hs => f0 t ann x
- | @DVector _ n ann x => fun hs => f1 n ann x (fun s : {n' : nat | (n' < n)%nat} => F DTfloat (x s) _)
- | @DMatrix _ n m ann x => fun hs =>
- f2 n m ann x
- (fun (s : {n' : nat | (n' < n)%nat}) (s0 : {m' : nat | (m' < m)%nat}) =>
- F DTfloat (x s s0) _)
- | Var v ann => fun hs => f3 v ann
- | Plus ann l r => fun hs => f4 ann l (F DTfloat l (proj1 hs)) r (F DTfloat r (proj2 hs))
- | Minus ann l r => fun hs => f5 ann l (F DTfloat l _) r (F DTfloat r _)
- | Times ann l r => fun hs => f6 ann l (F DTfloat l _) r (F DTfloat r _)
- | Divide ann l r => fun hs => f7 ann l (F DTfloat l _) r (F DTfloat r _)
- | Square ann e => fun hs => f8 ann e (F DTfloat e _)
- | Exp ann e => fun hs => f9 ann e (F DTfloat e _)
- | Log ann e => fun hs => f10 ann e (F DTfloat e _)
- | Abs ann e => fun hs => f11 ann e (F DTfloat e _)
- | Sign ann e => fun hs => f12 ann e (F DTfloat e _)
- | PSign ann e => fun hs => f13 ann e (F DTfloat e _)
- | Max ann l r => fun hs => f14 ann l (F DTfloat l _) r (F DTfloat r _)
- | @VectorDot _ n ann l r => fun hs => f15 n ann l (F (DTVector n) l _) r (F (DTVector n) r _)
- | @VectorSum _ n ann v => fun hs => f16 n ann v (F (DTVector n) v _)
- | @MatrixSum _ m n ann v => fun hs => f17 m n ann v (F (DTMatrix m n) v _)
- | @VectorElem _ n ann l i => fun hs => f18 n ann l (F (DTVector n) l _) i
- | @MatrixElem _ m n ann l i j => fun hs => f19 m n ann l (F (DTMatrix m n) l _) i j
- | @MatrixVectorMult _ m n ann l r => fun hs =>
- f20 m n ann l (F (DTMatrix m n) l _) r (F (DTVector n) r _)
- | @MatrixVectorAdd _ m n ann l r => fun hs =>
- f21 m n ann l (F (DTMatrix m n) l _) r (F (DTVector m) r _)
- | @MatrixMult _ m p n ann l r => fun hs =>
- f22 m p n ann l (F (DTMatrix m p) l _) r (F (DTMatrix p n) r _)
- | @VectorPlus _ n ann l r => fun hs => f23 n ann l (F (DTVector n) l _) r (F (DTVector n) r _)
- | @VectorMinus _ n ann l r => fun hs => f24 n ann l (F (DTVector n) l _) r (F (DTVector n) r _)
- | @MatrixPlus _ n m ann l r => fun hs => f25 n m ann l (F (DTMatrix n m) l _) r (F (DTMatrix n m) r _)
- | @MatrixMinus _ n m ann l r => fun hs =>
- f26 n m ann l (F (DTMatrix n m) l _) r (F (DTMatrix n m) r _)
- | @VectorScalMult _ n ann x l => fun hs => f27 n ann x (F DTfloat x _) l (F (DTVector n) l _)
- | @MatrixScalMult _ n m ann x l => fun hs => f28 n m ann x (F DTfloat x _) l (F (DTMatrix n m) l _)
- | @VectorApply _ n ann v s l => fun hs => f29 n ann v s _ (F DTfloat s _) l (F (DTVector n) l _)
- | @MatrixApply _ m n ann v s l => fun hs =>
- f30 m n ann v s _ (F DTfloat s _) l (F (DTMatrix m n) l _)
- | @VLossfun _ n ann v1 v2 s l r => fun hs =>
- f31 n ann v1 v2 s _ (F DTfloat s _) l (F (DTVector n) l _) r
- | @MLossfun _ m n ann v1 v2 s l r => fun hs =>
- f32 m n ann v1 v2 s _ (F DTfloat s _) l (F (DTMatrix m n) l _) r
- end); simpl in hs; intuition.
- - exact (proj1 (vforall_forall has_scalar_functions x) hs s).
- - rewrite vforall_forall in hs.
- specialize (hs s).
- rewrite vforall_forall in hs.
- specialize (hs s0).
- exact hs.
- Defined.
-
- Fixpoint is_df_rec_prop {Ann} {T}
- (prop : forall TT:definition_function_types,
- (DefinedFunction Ann TT) -> Prop)
- (df:DefinedFunction Ann T) {struct df}: Prop
- := prop T df /\
- match df with
- | Number _ _ => True
- | Constant _ _ _ => True
- | DVector n _ vec =>
- vforall (is_df_rec_prop prop) vec
- | DMatrix n m _ mat =>
- vforall (vforall (is_df_rec_prop prop)) mat
- | Var _ _ => True
- | Plus _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r)
- | Minus _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r)
- | Times _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r)
- | Divide _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r)
- | Square _ l => is_df_rec_prop prop l
- | Exp _ l => is_df_rec_prop prop l
- | Log _ l => is_df_rec_prop prop l
- | Abs _ l => is_df_rec_prop prop l
- | Sign _ l => is_df_rec_prop prop l
- | PSign _ l => is_df_rec_prop prop l
- | Max _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r)
- | VectorDot n _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | VectorSum _ _ l => is_df_rec_prop prop l
- | MatrixSum _ _ _ l => is_df_rec_prop prop l
- | VectorElem _ _ vec i => is_df_rec_prop prop vec
- | MatrixElem _ _ _ mat i j => is_df_rec_prop prop mat
- | MatrixVectorMult _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | MatrixVectorAdd _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | MatrixMult _ _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | VectorPlus _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | VectorMinus _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | MatrixPlus _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | MatrixMinus _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | VectorScalMult _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | MatrixScalMult _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | VectorApply _ _ _ _ l => is_df_rec_prop prop l
- | MatrixApply _ _ _ _ _ l => is_df_rec_prop prop l
- | VLossfun _ _ _ _ _ l _ => is_df_rec_prop prop l
- | MLossfun _ _ _ _ _ _ l _ => is_df_rec_prop prop l
- end.
-
- Fixpoint df_strip_annotations {Ann} {T}
- (df:DefinedFunction Ann T) {struct df}: DefinedFunction UnitAnn T
- :=
- match df with
- | Number _ x1 => Number tt x1
- | Constant t _ x => Constant tt x
- | DVector n _ vec => DVector tt (vmap df_strip_annotations vec)
- | DMatrix n m _ mat => DMatrix tt (vmap (vmap df_strip_annotations) mat)
- | Var v _ => Var v tt
- | Plus _ l r => Plus tt (df_strip_annotations l) (df_strip_annotations r)
- | Minus _ l r => Minus tt (df_strip_annotations l) (df_strip_annotations r)
- | Times _ l r => Times tt (df_strip_annotations l) (df_strip_annotations r)
- | Divide _ l r => Divide tt (df_strip_annotations l) (df_strip_annotations r)
- | Square _ l => Square tt (df_strip_annotations l)
- | Exp _ l => Exp tt (df_strip_annotations l)
- | Log _ l => Log tt (df_strip_annotations l)
- | Abs _ l => Abs tt (df_strip_annotations l)
- | Sign _ l => Sign tt (df_strip_annotations l)
- | PSign _ l => PSign tt (df_strip_annotations l)
- | Max _ l r => Max tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorDot n _ l r => VectorDot tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorSum n _ l => VectorSum tt (df_strip_annotations l)
- | MatrixSum m n _ l => MatrixSum tt (df_strip_annotations l)
- | VectorElem n _ vec i => VectorElem tt (df_strip_annotations vec) i
- | MatrixElem m n _ mat i j => MatrixElem tt (df_strip_annotations mat) i j
- | MatrixVectorMult m n _ l r => MatrixVectorMult tt (df_strip_annotations l) (df_strip_annotations r)
- | MatrixVectorAdd m n _ l r => MatrixVectorAdd tt (df_strip_annotations l) (df_strip_annotations r)
- | MatrixMult m p n _ l r => MatrixMult tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorPlus n _ l r => VectorPlus tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorMinus n _ l r => VectorMinus tt (df_strip_annotations l) (df_strip_annotations r)
- | MatrixPlus m n _ l r => MatrixPlus tt (df_strip_annotations l) (df_strip_annotations r)
- | MatrixMinus m n _ l r => MatrixMinus tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorScalMult n _ l r => VectorScalMult tt (df_strip_annotations l) (df_strip_annotations r)
- | MatrixScalMult m n _ l r => MatrixScalMult tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorApply n _ v s l => VectorApply tt v (df_strip_annotations s) (df_strip_annotations l)
- | MatrixApply m n _ v s l => MatrixApply tt v (df_strip_annotations s) (df_strip_annotations l)
- | VLossfun n _ v1 v2 s l r => VLossfun tt v1 v2 (df_strip_annotations s) (df_strip_annotations l) r
- | MLossfun m n _ v1 v2 s l r => MLossfun tt v1 v2 (df_strip_annotations s) (df_strip_annotations l) r
- end.
-
-
- Lemma df_strip_annotations_id {T} (df:DefinedFunction UnitAnn T) : df_strip_annotations df = df.
- Proof.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case; simpl; trivial
- ; destruct ann; trivial; try congruence.
- - Case "DVector"%string.
- f_equal.
- erewrite vmap_ext; [apply vmap_id | ]; intros.
- simpl.
- destruct H0 as [??]; subst.
- eapply H; eauto.
- - Case "DMatrix"%string.
- f_equal.
- erewrite vmap_ext; [apply vmap_id | ]; intros.
- simpl.
- erewrite vmap_ext; [apply vmap_id | ]; intros.
- simpl.
- destruct H0 as [??]; subst.
- destruct H1 as [??]; subst.
- eapply H; eauto.
- Qed.
-
- Definition df_eq_upto_annotations {Ann1 Ann2 T}
- (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) : Prop
- := df_strip_annotations df1 = df_strip_annotations df2.
-
- Definition is_df_evalann_correct_top (σ:df_env) {T} (df:DefinedFunction EvalAnn T)
- := df_eval σ df = Some (get_annotation df).
-
- Definition is_df_evalann_correct (σ:df_env) {T} (df:DefinedFunction EvalAnn T)
- := is_df_rec_prop (@is_df_evalann_correct_top σ) df.
-
- Lemma is_df_rec_prop_top {Ann} {T}
- {prop : forall TT:definition_function_types,
- (DefinedFunction Ann TT) -> Prop}
- {df:DefinedFunction Ann T} :
- is_df_rec_prop prop df ->
- prop _ df.
- Proof.
- destruct df; simpl; tauto.
- Qed.
-
- Lemma df_eval_tree_correct {T Ann} (σ:df_env) (df:DefinedFunction Ann T) (dfann:DefinedFunction EvalAnn T):
- df_eval_tree σ df = Some dfann ->
- is_df_evalann_correct σ dfann.
- Proof.
- unfold is_df_evalann_correct, is_df_evalann_correct_top.
- revert dfann.
- DefinedFunction_cases (induction df) Case; simpl; intros dfann eqq
- ; try solve[case_eq (df_eval_tree σ df1)
- ; [intros adf1 a1eqq | intros a1eqq]
- ; rewrite a1eqq in eqq
- ; [| congruence]
- ; (case_eq (df_eval_tree σ df2)
- ; [intros adf2 a2eqq | intros a2eqq]
- ; rewrite a2eqq in eqq
- ; [| congruence]
- ; inversion eqq; simpl
- ; specialize (IHdf1 _ a1eqq)
- ; specialize (IHdf2 _ a2eqq)
- ; split; [| tauto]
- ; apply is_df_rec_prop_top in IHdf1
- ; apply is_df_rec_prop_top in IHdf2
- ; simpl in IHdf1, IHdf2
- ; rewrite IHdf1, IHdf2
- ; trivial)
-
- |
- case_eq (df_eval_tree σ df)
- ; [intros adf aeqq | intros aeqq]
- ; rewrite aeqq in eqq
- ; [| congruence]
- ; inversion eqq; simpl
- ; specialize (IHdf _ aeqq)
- ; split; [| tauto]
- ; apply is_df_rec_prop_top in IHdf
- ; simpl in IHdf
- ; rewrite IHdf
- ; trivial
- ].
-
- - Case "Number"%string.
- inversion eqq; subst.
- simpl; tauto.
- - Case "Constant"%string.
- inversion eqq; subst.
- simpl; tauto.
- - Case "DVector"%string.
- match_option_in eqq.
- invcs eqq.
- simpl.
- specialize (vectoro_to_ovector_forall_some_f eqq0)
- ; simpl
- ; clear eqq0; intros eqq0.
- split.
- + apply vectoro_to_ovector_forall_some_b_strong; intros i.
- specialize (H _ _ (eqq0 i)).
- apply is_df_rec_prop_top in H.
- simpl in *.
- rewrite vmap_nth; trivial.
- + apply vforall_forall; eauto.
- - Case "DMatrix"%string.
- match_option_in eqq.
- invcs eqq.
- simpl.
- unfold matrixo_to_omatrix in *.
- specialize (vectoro_to_ovector_forall_some_f eqq0)
- ; simpl
- ; clear eqq0; intros eqq0.
- split.
- + apply vectoro_to_ovector_forall_some_b_strong; intros i.
- apply vectoro_to_ovector_forall_some_b_strong; intros j.
- specialize (eqq0 i).
- specialize (vectoro_to_ovector_forall_some_f eqq0)
- ; simpl
- ; clear eqq0; intros eqq0.
- specialize (eqq0 j).
- specialize (H _ _ _ eqq0).
- apply is_df_rec_prop_top in H.
- simpl in *.
- repeat rewrite vmap_nth; trivial.
- + apply vforall_forall; intros.
- apply vforall_forall; intros.
- specialize (eqq0 i).
- specialize (vectoro_to_ovector_forall_some_f eqq0)
- ; simpl
- ; clear eqq0; intros eqq0.
- eauto.
- - Case "Var"%string.
- revert eqq.
- case_eq (vartlookup σ v) ; [| congruence].
- intros.
- inversion eqq; subst; simpl.
- rewrite H; tauto.
- - Case "VectorApply"%string.
- case_eq (df_eval_tree σ df2)
- ; [intros adf2 a2eqq | intros a2eqq]
- ; rewrite a2eqq in eqq
- ; [| congruence].
- specialize (IHdf2 _ a2eqq).
- match_option_in eqq.
- invcs eqq.
- simpl.
- split; trivial.
- apply is_df_rec_prop_top in IHdf2.
- simpl in IHdf2.
- rewrite IHdf2; trivial.
- - Case "MatrixApply"%string.
- case_eq (df_eval_tree σ df2)
- ; [intros adf2 a2eqq | intros a2eqq]
- ; rewrite a2eqq in eqq
- ; [| congruence].
- specialize (IHdf2 _ a2eqq).
- match_option_in eqq.
- invcs eqq.
- simpl.
- split; trivial.
- apply is_df_rec_prop_top in IHdf2.
- simpl in IHdf2.
- rewrite IHdf2; trivial.
- - Case "VLossfun"%string.
- case_eq (df_eval_tree σ df2)
- ; [intros adf2 a2eqq | intros a2eqq]
- ; rewrite a2eqq in eqq
- ; [| congruence].
- specialize (IHdf2 _ a2eqq).
- match_option_in eqq.
- invcs eqq.
- simpl.
- split; trivial.
- apply is_df_rec_prop_top in IHdf2.
- simpl in IHdf2.
- rewrite IHdf2, eqq0; trivial.
- - Case "MLossfun"%string.
- case_eq (df_eval_tree σ df2)
- ; [intros adf2 a2eqq | intros a2eqq]
- ; rewrite a2eqq in eqq
- ; [| congruence].
- specialize (IHdf2 _ a2eqq).
- match_option_in eqq.
- invcs eqq.
- simpl.
- split; trivial.
- apply is_df_rec_prop_top in IHdf2.
- simpl in IHdf2.
- rewrite IHdf2, eqq0; trivial.
- Qed.
-
-
- Lemma df_eval_tree_deriv_correct {T} {σ:df_env} {df:DefinedFunction EvalAnn T} :
- is_df_evalann_correct σ df ->
- forall (xv:var_type),
- (* let xv := (v, DTfloat) in *)
- df_eval_tree_deriv σ df xv = df_eval_deriv σ df xv.
- Proof.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case;
- intro iscor; destruct iscor;
- simpl; intros; trivial; unfold is_df_evalann_correct in *
- ; try solve
- [
- rewrite IHdf
- ; trivial
- |
- assert (is_df_evalann_correct_top σ df)
- ; [ apply is_df_rec_prop_top; trivial |
- unfold is_df_evalann_correct_top in H1
- ; rewrite H1
- ; rewrite IHdf
- ; trivial
- ]
- |
- rewrite IHdf1;
- [ rewrite IHdf2
- ; trivial
- ; tauto
- |
- tauto
- ]
- |
- rewrite IHdf1;
- [ assert (is_df_evalann_correct_top σ df2);
- [ apply is_df_rec_prop_top; trivial
- | unfold is_df_evalann_correct_top in H1;
- rewrite H1; trivial]
- | tauto]
- |
- destruct H0; rewrite IHdf1;
- [rewrite IHdf2;
- [assert (is_df_evalann_correct_top σ df1);
- [apply is_df_rec_prop_top; trivial
- | assert (is_df_evalann_correct_top σ df2);
- [ apply is_df_rec_prop_top; trivial
- |
- unfold is_df_evalann_correct_top in H2;
- unfold is_df_evalann_correct_top in H3;
- rewrite H2; rewrite H3; trivial ]]
- |
- tauto]
- |
- tauto]
- ].
- - Case "DVector"%string.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality.
- intro.
- apply H.
- destruct H0.
- rewrite vforall_forall in H1.
- eauto.
- - Case "DMatrix"%string.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality.
- intro.
- apply FunctionalExtensionality.functional_extensionality.
- intros.
- apply H.
- destruct H0.
- rewrite vforall_forall in H1.
- specialize (H1 x0).
- rewrite vforall_forall in H1.
- eauto.
- Qed.
-
- Lemma df_eval_tree_backprop_deriv_correct {T} (σ gradenv:df_env) (df:DefinedFunction EvalAnn T) (grad : definition_function_types_interp T) :
- is_df_evalann_correct σ df ->
- df_eval_tree_backprop_deriv σ df gradenv grad = df_eval_backprop_deriv σ df gradenv grad.
- Proof.
- revert gradenv grad.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case;
- intros; simpl;trivial
- ; try solve [
- destruct H;
- assert (is_df_evalann_correct σ df);
- [ unfold is_df_evalann_correct; trivial
- | apply is_df_rec_prop_top in H0;
- rewrite IHdf;
- [ unfold is_df_evalann_correct_top in H0;
- rewrite H0; trivial
- | trivial]]
- |
- destruct H;
- apply IHdf;
- unfold is_df_evalann_correct; trivial
- |
- destruct H; destruct H0; rewrite IHdf1;
- [ case_eq (df_eval_backprop_deriv σ df1 gradenv grad); [|congruence];
- intros;
- rewrite IHdf2; trivial
- | unfold is_df_evalann_correct; trivial]
-
- |
- destruct H;
- assert (is_df_evalann_correct σ df2);
- [ unfold is_df_evalann_correct; trivial
- | apply is_df_rec_prop_top in H0;
- unfold is_df_evalann_correct_top in H0;
- rewrite H0;
- match_destr;
- rewrite IHdf1; trivial]
- |
- destruct H; destruct H0; assert (is_df_evalann_correct σ df1);
- [ unfold is_df_evalann_correct; trivial
- | assert (is_df_evalann_correct σ df2);
- [ unfold is_df_evalann_correct; trivial
- | rewrite IHdf1; trivial;
- apply is_df_rec_prop_top in H0;
- apply is_df_rec_prop_top in H1;
- unfold is_df_evalann_correct_top in H0;
- unfold is_df_evalann_correct_top in H1;
- rewrite H0; rewrite H1;
- match_destr;
- rewrite IHdf2; trivial]]
- ].
- - Case "DVector"%string.
- destruct H0.
- rewrite vforall_forall in H1.
- unfold two_vector_env_iter_alt.
- f_equal; apply FunctionalExtensionality.functional_extensionality; intros
- ; apply FunctionalExtensionality.functional_extensionality ; intros.
- apply H; unfold is_df_evalann_correct; apply H1.
- - Case "DMatrix"%string.
- destruct H0.
- rewrite vforall_forall in H1.
- unfold two_matrix_env_iter_alt.
- f_equal; apply FunctionalExtensionality.functional_extensionality; intros
- ; apply FunctionalExtensionality.functional_extensionality; intros.
- f_equal; apply FunctionalExtensionality.functional_extensionality; intros
- ; apply FunctionalExtensionality.functional_extensionality; intros.
- apply H; unfold is_df_evalann_correct.
- specialize (H1 x0).
- rewrite vforall_forall in H1; apply H1.
- - Case "MatrixVectorAdd"%string.
- destruct H.
- destruct H0.
- assert (is_df_evalann_correct σ df1).
- unfold is_df_evalann_correct; trivial.
- assert (is_df_evalann_correct σ df2).
- unfold is_df_evalann_correct; trivial.
- rewrite IHdf1; trivial.
- match_destr.
- assert
- (list_env_iter
- (fun (i : {m' : nat | (m' < n)%nat}) (env : df_env) =>
- df_eval_tree_backprop_deriv σ df2 env (transpose grad i))
- (Some d) (bounded_seq0 n) =
- list_env_iter
- (fun (i : {m' : nat | (m' < n)%nat}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d)
- (bounded_seq0 n)).
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite IHdf2; trivial.
- rewrite H4; trivial.
- Qed.
-
- Lemma df_eval_ignores_ann {Ann T} {σ:df_env}
- (df:DefinedFunction Ann T) :
- df_eval σ df = df_eval σ (df_strip_annotations df).
- Proof.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; simpl; trivial
- ; try solve [
- rewrite IHdf; trivial
- |
- rewrite IHdf1;
- case_eq (df_eval σ (df_strip_annotations df1)); [|congruence];
- intros; rewrite IHdf2; trivial
- |
- rewrite IHdf1; trivial;
- case_eq (df_eval σ (df_strip_annotations df2)); [|congruence];
- intros; f_equal;
- apply FunctionalExtensionality.functional_extensionality; intros;
- f_equal; rewrite df_strip_annotations_id; trivial
- |
- rewrite IHdf1; trivial;
- case_eq (df_eval σ (df_strip_annotations df2)); [|congruence];
- intros; f_equal;
- rewrite df_strip_annotations_id; trivial
- ].
-
- - Case "DVector"%string.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite H.
- rewrite vmap_nth; trivial.
- - Case "DMatrix"%string.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H x0).
- rewrite H.
- rewrite vmap_nth.
- rewrite vmap_nth; trivial.
- Qed.
-
- Lemma df_eval_ignores_ann2 {Ann1 Ann2 T} {σ:df_env}
- (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) :
- df_eq_upto_annotations df1 df2 ->
- df_eval σ df1 = df_eval σ df2.
- Proof.
- assert (df_eval σ df1 = df_eval σ (df_strip_annotations df1)) by apply df_eval_ignores_ann.
- assert (df_eval σ df2 = df_eval σ (df_strip_annotations df2)) by apply df_eval_ignores_ann.
- congruence.
- Qed.
-
- Lemma df_eval_deriv_ignores_ann {Ann T} {σ:df_env}
- (df:DefinedFunction Ann T) :
- forall (xv:var_type),
- df_eval_deriv σ df xv = df_eval_deriv σ (df_strip_annotations df) xv.
- Proof.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; simpl; trivial
- ; try solve
- [
- intro; rewrite IHdf1;
- case_eq (df_eval_deriv σ (df_strip_annotations df1) xv); [|congruence];
- intros;
- rewrite IHdf2; trivial
- |
- intro; rewrite df_eval_ignores_ann;
- case_eq (df_eval σ (df_strip_annotations df1)); [|congruence];
- intros; rewrite IHdf1; intros;
- case_eq (df_eval_deriv σ (df_strip_annotations df1) xv); [|congruence];
- intros; rewrite df_eval_ignores_ann;
- case_eq (df_eval σ (df_strip_annotations df2)); [|congruence];
- intros; rewrite IHdf2; trivial
- |
- intros; rewrite df_eval_ignores_ann;
- case_eq (df_eval σ (df_strip_annotations df)); [|congruence];
- intros; rewrite IHdf; intros;
- case_eq (df_eval_deriv σ (df_strip_annotations df) xv); [|congruence];
- trivial
- |
- intros; rewrite IHdf; trivial
- |
- intro; rewrite df_eval_ignores_ann;
- case_eq (df_eval σ (df_strip_annotations df2)); [|congruence];
- intros; rewrite IHdf1; intros;
- rewrite df_strip_annotations_id; trivial
- ].
-
- - Case "DVector"%string.
- intros.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite H.
- rewrite vmap_nth; trivial.
- - Case "DMatrix"%string.
- intros.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H x0).
- rewrite H.
- rewrite vmap_nth.
- rewrite vmap_nth; trivial.
- - Case "Max"%string.
- intro; rewrite df_eval_ignores_ann.
- case_eq (df_eval σ (df_strip_annotations df1)); [|congruence].
- intros.
- rewrite df_eval_ignores_ann.
- case_eq (df_eval σ (df_strip_annotations df2)); [|congruence].
- intros.
- rewrite IHdf1.
- rewrite IHdf2; trivial.
- Qed.
-
- Lemma df_eval_deriv_ignores_ann2 {Ann1 Ann2 T} {σ:df_env}
- (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) :
- forall (xv:var_type),
- df_eq_upto_annotations df1 df2 ->
- df_eval_deriv σ df1 xv = df_eval_deriv σ df2 xv.
- Proof.
- intro.
- assert (df_eval_deriv σ df1 xv = df_eval_deriv σ (df_strip_annotations df1) xv) by apply df_eval_deriv_ignores_ann.
- assert (df_eval_deriv σ df2 xv = df_eval_deriv σ (df_strip_annotations df2) xv) by apply df_eval_deriv_ignores_ann.
- congruence.
- Qed.
-
- Lemma is_scalar_function_scalar {Ann} {T} (df:DefinedFunction Ann T) :
- is_scalar_function df -> is_scalar_df_type T.
- Proof.
- induction df; simpl; trivial.
- Qed.
-
-
-
- Definition definition_function_types_map_base (f:Type->Type) (dft:definition_function_types): Type
- := match dft with
- | DTfloat => f float
- | DTVector n => Vector (f float) n
- | DTMatrix m n => Matrix (f float) m n
- end.
-
- Definition definition_function_types_subgradient (dft:definition_function_types)
- := definition_function_types_map_base (fun t => list (list t)) dft.
-
-
- Definition df_eval_gradient {T} σ (df:DefinedFunction UnitAnn T) (lv:list var_type) : option (list (definition_function_types_interp T))
- := listo_to_olist (map (df_eval_deriv σ df) lv).
-
- Definition combine_prod (l1 l2 : list (list float)) : list (list (float * float))
- := let l12 := list_prod l1 l2
- in map (fun '(x,y) => combine x y) l12.
-(*
- Fixpoint df_eval_subgradient {dft:definition_function_types} (σ:df_env) (df:DefinedFunction dft) (lv:list SubVar) : option (definition_function_types_subgradient dft)
- := (match df with
- | Number _ => Some ((map (fun _ => 0) lv) :: nil)
- | DVector n v => vectoro_to_ovector (vmap (fun x => df_eval_subgradient σ x lv) v)
- | DMatrix n m df => matrixo_to_omatrix (vmap (fun x => vmap (fun y => df_eval_subgradient σ y lv) x) df)
- | Var x => Some ((map (fun v => if x == v then 1 else 0) lv) :: nil)
- | Plus l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (map (map (fun '(x, y) => x+y)) (combine_prod ld rd))
- | _, _ => None
- end
- | Minus l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (map (map (fun '(x, y) => x-y)) (combine_prod ld rd))
- | _, _ => None
- end
- | Times l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- Some (map (map (fun '(lp,rp) => lp*re + le*rp)) (combine_prod ld rd))
- | _, _, _, _ => None
- end
- | Divide l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- Some (map (map (fun '(lp,rp) => (lp*re - le*rp)/(re * re))) (combine_prod ld rd))
- | _, _, _, _ => None
- end
- | Square e =>
- match df_eval σ e, df_eval_subgradient σ e lv with
- | Some ee, Some ed => Some (map (map (fun pd => 2 * ee * pd)) ed)
- | _, _ => None
- end
- | Exp e =>
- match df_eval σ e, df_eval_subgradient σ e lv with
- | Some ee, Some ed => Some (map (map (fun pd => pd * Fexp ee)) ed)
- | _, _ => None
- end
- | Log e =>
- match df_eval σ e, df_eval_subgradient σ e lv with
- | Some ee, Some ed => Some (map (map (fun pd => (pd / ee))) ed)
- | _, _ => None
- end
- | Abs e =>
- match df_eval σ e, df_eval_subgradient σ e lv with
- | Some ee, Some ed =>
- if Feq ee 0 then Some (ed ++ (map (map (fun ep => -ep)) ed))
- else Some (map (map (fun ed => (ed * (sign ee)))) ed)
- | _, _ => None
- end
- | Sign e =>
- match df_eval σ e with
- | Some ee => Some ((map (fun _ => 0) lv) :: nil )
- | _ => None
- end
- | PSign e =>
- match df_eval σ e with
- | Some ee => Some ((map (fun _ => 0) lv) :: nil )
- | _ => None
- end
- | Max l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- if Feq le re then Some (ld ++ rd)
- else if le > re then Some ld
- else Some rd
- | _, _, _, _ => None
- end
- | VectorElem n l i =>
- match (df_eval_subgradient σ l lv) with
- | Some l' => Some (l' i)
- | _ => None
- end
- | MatrixElem m n l i j =>
- match (df_eval_subgradient σ l lv) with
- | Some l' => Some (l' i j)
- | _ => None
- end
- | VectorSum n l =>
- match df_eval_subgradient σ l lv with
- | Some l' =>
- Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp))
- (combine_prod a b))
- ((map (fun _ => 0) lv)::nil) l')
- | _ => None
- end
- | VectorDot n l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp))
- (combine_prod a b))
- ((map (fun _ => 0) lv)::nil)
- (fun i => map (map (fun '(lp,rp) => lp*(re i) + (le i)*rp))
- (combine_prod (ld i) (rd i))))
- | _, _, _, _ => None
- end
- | VectorScalMult n x r =>
- match df_eval σ x, df_eval_subgradient σ x lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some xe, Some xd, Some re, Some rd =>
- Some (fun j => map (map (fun '(xp,rp) => xe * rp + xp * (re j))) (combine_prod xd (rd j)))
- | _, _, _, _ => None
- end
- | MatrixScalMult n m x r =>
- match df_eval σ x, df_eval_subgradient σ x lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some xe, Some xd, Some re, Some rd =>
- Some (fun i j => map (map (fun '(xp,rp) => xe * rp + xp * (re i j))) (combine_prod xd (rd i j)))
-
- | _, _, _, _ => None
- end
- | MatrixVectorMult n m l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- Some (fun i =>
- (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp))
- (combine_prod a b))
- ((map (fun _ => 0) lv)::nil)
- (fun j => map (map (fun '(lp,rp) => lp*(re j) + (le i j)*rp))
- (combine_prod (ld i j) (rd j)))))
- | _, _, _, _ => None
- end
- | MatrixMult n m p l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- Some (fun i k =>
- (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp))
- (combine_prod a b))
- ((map (fun _ => 0) lv)::nil)
- (fun j => map (map (fun '(lp,rp) => lp*(re j k) + (le i j)*rp))
- (combine_prod (ld i j) (rd j k)))))
- | _, _, _, _ => None
- end
- | VectorPlus n l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (fun i => (map (map (fun '(x, y) => x+y)) (combine_prod (ld i) (rd i))))
- | _, _ => None
- end
- | VectorMinus n l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (fun i => (map (map (fun '(x, y) => x-y)) (combine_prod (ld i) (rd i))))
- | _, _ => None
- end
- | MatrixPlus n m l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (fun i j => (map (map (fun '(x, y) => x+y)) (combine_prod (ld i j) (rd i j))))
- | _, _ => None
- end
- | MatrixMinus n m l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (fun i j => (map (map (fun '(x, y) => x-y)) (combine_prod (ld i j) (rd i j))))
- | _, _ => None
- end
- | VectorApply n x s r =>
- match df_eval σ r, df_eval_subgradient σ r lv with
- | Some re, Some rd =>
- vectoro_to_ovector
- (fun i => match df_eval_subgradient (cons (x, re i) σ) s lv with
- | Some sd =>
- Some (map (map (fun '(x, y) => x*y)) (combine_prod (rd i) sd))
- | _ => None
- end)
- | _, _ => None
- end
- | Lossfun n v1 v2 s l r =>
- match df_eval σ l, df_eval_subgradient σ l lv with
- | Some le, Some ld =>
- match (vectoro_to_ovector
- (fun i => match df_eval_subgradient (cons (v1, (le i)) (cons (v2, r i) σ)) s lv with
- | Some sd => Some (map (map (fun '(x, y) => x*y)) (combine_prod (ld i) sd))
- | _ => None
- end)) with
- | Some vv => Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp))
- (combine_prod a b))
- ((map (fun _ => 0) lv)::nil) vv)
- | _ => None
- end
- | _, _ => None
- end
- end).
-*)
- End deriv2.
-
- Definition dft_one (dft:definition_function_types) : definition_function_types_interp dft
- := match dft with
- | DTfloat => 1
- | DTVector n => fun _ => 1
- | DTMatrix m n => fun _ _ => 1
- end.
-
- Section scalar_ind.
-
- Fixpoint is_scalar_function_ind_gen {Ann}
- {P:forall {T}, DefinedFunction Ann T->Prop}
- (fnumber:forall ann x, P (Number ann x))
- (fconstant:forall (ann:Ann DTfloat) x, P (@Constant _ DTfloat ann x))
- (fvar:forall sv ann, P (@Var _ (sv,DTfloat) ann))
- (fplus:forall a l r, P l -> P r -> P (Plus a l r))
- (fminus:forall a l r, P l -> P r -> P (Minus a l r))
- (ftimes:forall a l r, P l -> P r -> P (Times a l r))
- (fdivide:forall a l r, P l -> P r -> P (Divide a l r))
- (fsquare:forall a e, P e -> P (Square a e))
- (fexp:forall a e, P e -> P (Exp a e))
- (flog:forall a e, P e -> P (Log a e))
- (fabs:forall a e, P e -> P (Abs a e))
- (fsign:forall a e, P e -> P (Sign a e))
- (fpsign:forall a e, P e -> P (PSign a e))
- (fmax:forall a l r, P l -> P r -> P (Max a l r))
- {T}
- (df:DefinedFunction Ann T) {struct df} : is_scalar_function df -> P df.
- Proof.
- induction df; simpl; intros isc; try tauto.
- - apply fnumber.
- - destruct t; simpl in isc; try tauto.
- apply fconstant.
- - destruct v.
- destruct d; simpl in isc; try tauto.
- apply fvar.
- - apply fplus.
- + apply IHdf1; tauto.
- + apply IHdf2; tauto.
- - apply fminus.
- + apply IHdf1; tauto.
- + apply IHdf2; tauto.
- - apply ftimes.
- + apply IHdf1; tauto.
- + apply IHdf2; tauto.
- - apply fdivide.
- + apply IHdf1; tauto.
- + apply IHdf2; tauto.
- - apply fsquare.
- + apply IHdf; tauto.
- - apply fexp.
- + apply IHdf; tauto.
- - apply flog.
- + apply IHdf; tauto.
- - apply fabs.
- + apply IHdf; tauto.
- - apply fsign.
- + apply IHdf; tauto.
- - apply fpsign.
- + apply IHdf; tauto.
- - apply fmax.
- + apply IHdf1; tauto.
- + apply IHdf2; tauto.
- Qed.
-
- Definition is_scalar_function_ind {Ann}
- {P:DefinedFunction Ann DTfloat->Prop}
- (fnumber:forall ann x, P (Number ann x))
- (fconstant:forall (ann:Ann DTfloat) x, P (@Constant _ DTfloat ann x))
- (fvar:forall sv ann, P (@Var _ (sv,DTfloat) ann))
- (fplus:forall a l r, P l -> P r -> P (Plus a l r))
- (fminus:forall a l r, P l -> P r -> P (Minus a l r))
- (ftimes:forall a l r, P l -> P r -> P (Times a l r))
- (fdivide:forall a l r, P l -> P r -> P (Divide a l r))
- (fsquare:forall a e, P e -> P (Square a e))
- (fexp:forall a e, P e -> P (Exp a e))
- (flog:forall a e, P e -> P (Log a e))
- (fabs:forall a e, P e -> P (Abs a e))
- (fsign:forall a e, P e -> P (Sign a e))
- (fpsign:forall a e, P e -> P (PSign a e))
- (fmax:forall a l r, P l -> P r -> P (Max a l r))
- (df:DefinedFunction Ann DTfloat) : is_scalar_function df -> P df.
- Proof.
- apply (@is_scalar_function_ind_gen _ (fun t => match t with
- | DTfloat => fun df => P df
- | _ => fun _ => False
- end)); trivial.
- Qed.
-
- Definition vartlookup_eq (l1 l2:df_env) : Prop := forall a, vartlookup l1 a = vartlookup l2 a.
-
- Global Instance vartlookup_eq_equiv : Equivalence vartlookup_eq.
- Proof.
- unfold vartlookup_eq.
- constructor; red.
- - intros; reflexivity.
- - intros; eauto.
- - intro; etransitivity; eauto.
- Qed.
-
- End scalar_ind.
-
- Lemma lookup_update (xv : var_type) (gradenv : df_env)
- (val : definition_function_types_interp (snd xv)) :
- vartlookup (vart_update gradenv xv val) xv = Some val.
- Proof.
- induction gradenv; simpl.
- - destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence].
- refl_simpler; simpl; trivial.
- - destruct a; simpl.
- case_eq (@equiv_dec var_type _ _ _ xv x); simpl; intros.
- + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence].
- refl_simpler; simpl; trivial.
- + rewrite H; trivial.
- Qed.
-
- Lemma lookup_update_neq (xv1 xv2 : var_type) (gradenv : df_env)
- (val : definition_function_types_interp (snd xv1)) : xv1 <> xv2 ->
- vartlookup (vart_update gradenv xv1 val) xv2 = vartlookup gradenv xv2.
- Proof.
- intros neq.
- induction gradenv; simpl.
- - destruct (@equiv_dec var_type _ _ _ xv2 xv1); congruence.
- - destruct a; simpl.
- case_eq (@equiv_dec var_type _ _ _ xv1 x); simpl; intros.
- + destruct (@equiv_dec var_type _ _ _ xv2 xv1); [congruence | ].
- destruct (@equiv_dec var_type _ _ _ xv2 x); congruence.
- + destruct (@equiv_dec var_type _ _ _ xv2 x); congruence.
- Qed.
-
- Lemma lookup_update2 (xv : var_type) (gradenv : df_env)
- (val : definition_function_types_interp (snd xv)) :
- vartlookup ((mk_env_entry xv val) :: gradenv) xv = Some val.
- Proof.
- induction gradenv; simpl.
- - destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence].
- refl_simpler; simpl; trivial.
- - destruct a; simpl.
- case_eq (@equiv_dec var_type _ _ _ xv x); simpl; intros.
- + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence].
- refl_simpler; simpl; trivial.
- + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence].
- refl_simpler; simpl; trivial.
- Qed.
-
-
-Tactic Notation "DefinedFunction_scalar_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string].
-
- Lemma df_eval_backprop_deriv_preserves_lookup_not_none {Ann T} {env} {grad gradenv d} {df:DefinedFunction Ann T} :
- df_eval_backprop_deriv env df gradenv grad = Some d ->
- forall xv,
- vartlookup gradenv xv <> None ->
- vartlookup d xv <> None.
- Proof.
- simpl.
- revert grad gradenv d.
- DefinedFunction_cases (induction df) Case; simpl.
- - Case "Number"%string; intros; inversion H; subst; easy.
- - Case "Constant"%string; intros; inversion H; subst; easy.
- - Case "DVector"%string.
- intros grad.
- unfold two_vector_env_iter_alt.
- induction (bounded_seq0 n).
- simpl.
- intros.
- inversion H0; subst; trivial.
- simpl.
- intros gradenv d.
- case_eq (df_eval_backprop_deriv env (x a) gradenv (grad a)).
- intros.
- specialize (H a (grad a) gradenv d0).
- specialize (IHl d0 d).
- apply IHl; trivial.
- apply H; trivial.
- intros.
- assert (list_env_iter
- (fun (i : {n' : nat | (n' < n)%nat}) (env0 : df_env) =>
- df_eval_backprop_deriv env (x i) env0 (grad i)) None l = None)
- by apply list_env_iter_none.
- intros; rewrite H1 in H3; discriminate.
- - Case "DMatrix"%string.
- intros grad.
- unfold two_matrix_env_iter_alt.
- induction (bounded_seq0 n); simpl.
- { intros; inversion H0; subst; trivial. }
- intros gradenv d eqq.
- case_eq ((list_env_iter
- (fun (j : {m' : nat | (m' < m)%nat}) (env0 : df_env) =>
- df_eval_backprop_deriv env (x a j) env0 (grad a j)) (Some gradenv)
- (bounded_seq0 m)))
- ; [ intros dd ddeqq | intros ddeqq]
- ; rewrite ddeqq in eqq
- ; simpl in eqq
- ; [| destruct l; simpl; discriminate].
- specialize (IHl _ _ eqq).
- cut (forall xv : var_type, vartlookup gradenv xv <> None -> vartlookup dd xv <> None)
- ; [ eauto | ].
- clear d IHl eqq.
- revert gradenv dd ddeqq.
- induction (bounded_seq0 m); simpl
- ; intros gradenv dd ddeqq
- ; simpl in ddeqq.
- { inversion ddeqq; subst; trivial. }
- case_eq (df_eval_backprop_deriv env (x a a0) gradenv (grad a a0))
- ; [intros dd2 ddeqq2 | intros ddeqq2]
- ; rewrite ddeqq2 in ddeqq
- ; simpl in ddeqq
- ; [| destruct l0; simpl; discriminate].
- eauto.
- - Case "Var"%string.
- intros.
- destruct (vartlookup gradenv v) ; [|congruence].
- intros.
- inversion H.
- destruct (vart_dec v xv).
- + subst; rewrite lookup_update.
- discriminate.
- + rewrite lookup_update_neq; trivial.
- - Case "Plus"%string.
- intros grad gradenv.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d).
- specialize (IHdf2 grad d d0).
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "Minus"%string.
- intros grad gradenv.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d).
- specialize (IHdf2 (-grad) d d0).
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "Times"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv (d1 * grad)) ; [|congruence].
- intros d2.
- specialize (IHdf1 (d1 * grad) gradenv d2).
- specialize (IHdf2 (d0 * grad) d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "Divide"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv (grad / d1)) ; [|congruence].
- intros d2.
- specialize (IHdf1 (grad / d1) gradenv d2).
- specialize (IHdf2 (- d0 / (d1 * d1) * grad) d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "Square"%string; intros.
- destruct (df_eval env df) ; [|congruence].
- specialize (IHdf (2 * d0 * grad) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "Exp"%string; intros.
- destruct (df_eval env df) ; [|congruence].
- specialize (IHdf (grad * Fexp d0) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "Log"%string; intros.
- destruct (df_eval env df) ; [|congruence].
- specialize (IHdf (grad / d0) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "Abs"%string; intros.
- destruct (df_eval env df) ; [|congruence].
- specialize (IHdf (grad * sign d0) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "Sign"%string; intros.
- specialize (IHdf 0 gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "PSign"%string; intros.
- specialize (IHdf 0 gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "Max"%string; intros.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- destruct (d0 <= d1).
- specialize (IHdf2 grad gradenv d).
- apply IHdf2.
- apply H.
- trivial.
- specialize (IHdf1 grad gradenv d).
- apply IHdf1.
- apply H.
- trivial.
- - Case "VectorDot"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv
- (vmap (fun rv : float => rv * grad) d1)); [|congruence].
- intros.
- specialize (IHdf1 (vmap (fun rv : float => rv * grad) d1) gradenv d2).
- specialize (IHdf2 (vmap (fun lv : float => lv * grad) d0) d2 d).
- apply IHdf2.
- apply H0.
- apply IHdf1.
- apply H.
- apply H1.
- - Case "VectorSum"%string.
- intros.
- specialize (IHdf (ConstVector n grad) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "MatrixSum"%string.
- intros.
- specialize (IHdf (ConstMatrix m n grad) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "VectorElem"%string.
- intros.
- specialize (IHdf (fun k : {n' : nat | (n' < n)%nat} =>
- if equiv_dec (proj1_sig k) (proj1_sig i)
- then grad else 0) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "MatrixElem"%string.
- intros.
- specialize (IHdf (fun (k1 : {n' : nat | (n' < m)%nat})
- (k2 : {m' : nat | (m' < n)%nat}) =>
- if equiv_dec (proj1_sig k1) (proj1_sig i)
- then if equiv_dec (proj1_sig k2) (proj1_sig j) then grad else 0
- else 0) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "MatrixVectorMult"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv
- (fun (i : {n' : nat | (n' < m)%nat})
- (j : {m' : nat | (m' < n)%nat}) => grad i * d1 j)) ; [|congruence].
- intros d2.
- specialize (IHdf1 (fun (i : {n' : nat | (n' < m)%nat})
- (j : {m' : nat | (m' < n)%nat}) => grad i * d1 j) gradenv d2).
- specialize (IHdf2 (matrix_vector_mult
- (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < m)%nat}) => d0 j i) grad) d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "MatrixVectorAdd"%string.
- intros grad gradenv d.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad); [|congruence].
- intros d0 casedf1.
- specialize (IHdf1 _ _ _ casedf1).
- clear casedf1.
- revert gradenv d d0 IHdf1.
- induction (bounded_seq0 n).
- + simpl.
- intros.
- inversion H; subst.
- eauto.
- + intros gradenv d d0 d0eqq.
- simpl.
- case_eq (df_eval_backprop_deriv env df2 d0 (transpose grad a)); simpl
- ; [intros ? eqq1 | intros eqq1].
- * intros.
- { apply (IHl d0 _ d1); trivial.
- - eapply IHdf2; eauto.
- - eauto.
- }
- * destruct l; simpl; discriminate.
- - Case "MatrixMult"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv
- (matrix_mult grad
- (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < p)%nat}) => d1 j i)))
- ; [|congruence].
- intros d2.
- specialize (IHdf1 (matrix_mult grad
- (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < p)%nat}) => d1 j i))
- gradenv d2).
- specialize (IHdf2 (matrix_mult
- (fun (i : {n' : nat | (n' < p)%nat})
- (j : {m' : nat | (m' < m)%nat}) => d0 j i) grad) d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "VectorPlus"%string.
- intros grad gradenv d.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d0).
- specialize (IHdf2 grad d0 d).
- apply IHdf2.
- apply H0.
- apply IHdf1.
- apply H.
- apply H1.
- - Case "VectorMinus"%string.
- intros grad gradenv d.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d0).
- specialize (IHdf2 (fun i : {n' : nat | (n' < n)%nat} => - grad i) d0 d).
- apply IHdf2.
- apply H0.
- apply IHdf1.
- apply H.
- apply H1.
- - Case "MatrixPlus"%string.
- intros grad gradenv d.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d0).
- specialize (IHdf2 grad d0 d).
- apply IHdf2.
- apply H0.
- apply IHdf1.
- apply H.
- apply H1.
- - Case "MatrixMinus"%string.
- intros grad gradenv d.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d0).
- specialize (IHdf2 (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < m)%nat}) => - grad i j)
- d0 d).
- apply IHdf2.
- apply H0.
- apply IHdf1.
- apply H.
- apply H1.
- - Case "VectorScalMult"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv
- (vsum (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad j)))
- ; [|congruence].
- intros d2.
- specialize (IHdf1 (vsum (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad j))
- gradenv d2).
- specialize (IHdf2 (fun j : {n' : nat | (n' < n)%nat} => d0 * grad j)
- d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "MatrixScalMult"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv
- (msum
- (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < m)%nat}) =>
- d1 i j * grad i j)))
- ; [|congruence].
- intros d2.
- specialize (IHdf1 (msum
- (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < m)%nat}) =>
- d1 i j * grad i j))
- gradenv d2).
- specialize (IHdf2 (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < m)%nat}) => grad i j * d0)
- d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "VectorApply"%string.
- intros grad gradenv d.
- destruct (df_eval env df2) ; [|congruence].
- simpl in *.
- match_destr; simpl; eauto.
- - Case "MatrixApply"%string.
- intros grad gradenv d.
- destruct (df_eval env df2) ; [|congruence].
- match_destr; simpl; eauto.
- - Case "VLossfun"%string.
- intros grad gradenv d.
- destruct (df_eval env df2) ; [|congruence].
- match_destr; simpl; eauto.
- - Case "MLossfun"%string.
- intros grad gradenv d.
- destruct (df_eval env df2) ; [|congruence].
- match_destr; simpl; eauto.
- Qed.
-
- Definition df_eval_deriv_gen_top {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) :
- option (lifted_type (definition_function_types_interp T) (snd v)) :=
- match (snd v) as vt return option (lifted_type (definition_function_types_interp T) vt) with
- | DTfloat => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTfloat) 1)::nil)
- | DTVector n =>
- vectoro_to_ovector
- (fun i => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTVector n) (UnitVector n i))::nil))
- | DTMatrix n m =>
- matrixo_to_omatrix
- (fun i j => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTMatrix n m) (UnitMatrix n m i j))::nil))
- end.
-
- Program Definition subvar (x : var_type) (grad_env:df_env) :=
- (match snd x as y return snd x = y ->
- definition_function_types_interp y ->
- definition_function_types_interp y with
- | DTfloat => fun pf grad => match vartlookup grad_env x with
- | Some val => ((coerce _ val):float) - grad
- | _ => Fopp grad
- end
- | DTVector n => fun pf grad => match vartlookup grad_env x with
- | Some val => fun i => (((coerce _ val):Vector float n) i) - (grad i)
- | _ => vmap Fopp grad
- end
- | DTMatrix m n => fun pf grad => match vartlookup grad_env x with
- | Some val => fun i j => (((coerce _ val):Matrix float m n) i j) - (grad i j)
- | _ => vmap (vmap Fopp) grad
- end
- end) (eq_refl _).
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
-
- Definition df_eval_backprop_delta {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) (grad_env:df_env) (grad: definition_function_types_interp T) :
- option (definition_function_types_interp (snd v)) :=
- match vartlookup grad_env v with
- | Some old =>
- lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env grad)
- | None => None
- end.
-
- Program Definition df_eval_backward_gen_top {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) (grad_env:df_env) :
- option (lifted_type (definition_function_types_interp (snd v)) T) :=
- match vartlookup grad_env v with
- | Some old =>
- (match T as vt return T = vt -> option (lifted_type (definition_function_types_interp (snd v)) vt) with
- | DTfloat => fun pf => (lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ 1)))
- | DTVector n => fun pf =>
- vectoro_to_ovector
- (fun i => lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ (UnitVector n i))))
- | DTMatrix m n => fun pf =>
- matrixo_to_omatrix
- (fun i j => lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ (UnitMatrix m n i j))))
- end) (eq_refl _)
- | None => None
- end.
-
- Definition transpose_lifted_type {T1 T2} :
- lifted_type (definition_function_types_interp T1) T2 ->
- lifted_type (definition_function_types_interp T2) T1
- := match T1, T2 with
- | DTfloat, _ => fun inp => inp
- | _, DTfloat => fun inp => inp
- | DTVector n1, DTVector n2 => fun inp => fun i j => inp j i
- | DTMatrix m1 n1, DTMatrix m2 n2 => fun inp => fun i j p q => inp p q i j
- | DTVector n1, DTMatrix m2 n2 => fun inp => fun i p q => inp p q i
- | DTMatrix m1 n1, DTVector n2 => fun inp => fun i j p => inp p i j
- end.
- Section deriv_deriv.
- End deriv_deriv.
-
- Section max_derived.
- Definition MaxDerived (a b : DefinedFunction UnitAnn DTfloat) :=
- Divide tt (Plus tt (Plus tt (Abs tt (Minus tt b a)) b) a) (Number tt 2).
-
- Delimit Scope df_scope with df.
-
- Notation "x + y" := (Plus x y) (only printing) : df_scope.
- Notation "x - y" := (Minus x y) (only printing) : df_scope.
- Notation "x / y" := (Divide x y) (only printing) : df_scope.
- Notation "x * y" := (Times x y) (only printing) : df_scope.
- Notation "x" := (Number x) (only printing, at level 0) : df_scope.
- Notation "x" := (Var x) (only printing, at level 0) : df_scope.
- Notation "'|' x '|'" := (Abs x) (only printing, at level 0) : df_scope.
-
- End max_derived.
-
- Section fv.
-
- Fixpoint df_free_variables {Ann} {T} (f : DefinedFunction Ann T) : list var_type
- := match f with
- | Number _ x => nil
- | DVector n _ x => vlconcat_map df_free_variables x
- | Constant t _ x => nil
- | DMatrix n m _ x => vlconcat_map (fun a => vlconcat_map df_free_variables a) x
- | Var v _ => v::nil
- | Plus _ l r => (df_free_variables l) ++ (df_free_variables r)
- | Minus _ l r => (df_free_variables l) ++ (df_free_variables r)
- | Times _ l r => (df_free_variables l) ++ (df_free_variables r)
- | Divide _ l r => (df_free_variables l) ++ (df_free_variables r)
- | Max _ l r => (df_free_variables l) ++ (df_free_variables r)
- | Abs _ e => df_free_variables e
- | Sign _ e => df_free_variables e
- | PSign _ e => df_free_variables e
- | Log _ e => df_free_variables e
- | Square _ e => df_free_variables e
- | Exp _ e => df_free_variables e
- | VectorElem n _ l i => df_free_variables l
- | MatrixElem m n _ l i j => df_free_variables l
- | VectorDot n _ l r => (df_free_variables l) ++ (df_free_variables r)
- | VectorSum n _ l => df_free_variables l
- | MatrixSum n m _ l => df_free_variables l
- | VectorScalMult n _ x r => (df_free_variables x) ++ (df_free_variables r)
- | MatrixScalMult n m _ x r => (df_free_variables x) ++ (df_free_variables r)
- | MatrixVectorMult n m _ l r => (df_free_variables l) ++ (df_free_variables r)
- | MatrixVectorAdd n m _ l r => (df_free_variables l) ++ (df_free_variables r)
- | MatrixMult n m p _ l r => (df_free_variables l) ++ (df_free_variables r)
- | VectorPlus n _ l r => (df_free_variables l) ++ (df_free_variables r)
- | VectorMinus n _ l r => (df_free_variables l) ++ (df_free_variables r)
- | MatrixPlus n m _ l r => (df_free_variables l) ++ (df_free_variables r)
- | MatrixMinus n m _ l r => (df_free_variables l) ++ (df_free_variables r)
- | VectorApply n _ x s l => (remove_all (x,DTfloat) (df_free_variables s))
- ++ (df_free_variables l)
- | MatrixApply n m _ x s l => (remove_all (x,DTfloat) (df_free_variables s))
- ++ (df_free_variables l)
- | VLossfun n _ v1 v2 s l r => (remove_all (v1,DTfloat) (remove_all (v2,DTfloat) (df_free_variables s)))
- ++ (df_free_variables l)
- | MLossfun n m _ v1 v2 s l r => (remove_all (v1,DTfloat) (remove_all (v2,DTfloat) (df_free_variables s)))
- ++ (df_free_variables l)
- end.
-
- Definition df_closed {Ann} {T} (f: DefinedFunction Ann T) : Prop
- := match df_free_variables f with
- | nil => True
- | _ => False
- end.
-
- Lemma df_closed_nil {T} (f: DefinedFunction UnitAnn T) : df_closed f -> df_free_variables f = nil.
- Proof.
- unfold df_closed.
- destruct (df_free_variables f); tauto.
- Qed.
-
- Definition df_closed_over {Ann} {T} (f : DefinedFunction Ann T) (vl : list var_type) : Prop
- := incl (df_free_variables f) vl.
-
- Fixpoint fully_closed_over {Ann} {T} (df : DefinedFunction Ann T) (vl : list var_type) : Prop
- :=
- match df with
- | Number _ x => True
- | DVector n _ x => vforall (fun f => fully_closed_over f vl) x
- | Constant t _ x => True
- | DMatrix n m _ x => vforall (fun row =>
- (vforall (fun f => fully_closed_over f vl) row)) x
- | Var v _ => In v vl
- | Plus _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | Minus _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | Times _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | Divide _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | Max _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | Abs _ e => fully_closed_over e vl
- | Sign _ e => fully_closed_over e vl
- | PSign _ e => fully_closed_over e vl
- | Log _ e => fully_closed_over e vl
- | Square _ e => fully_closed_over e vl
- | Exp _ e => fully_closed_over e vl
- | VectorElem n _ l i => fully_closed_over l vl
- | MatrixElem m n _ l i j => fully_closed_over l vl
- | VectorDot n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | VectorSum n _ l => fully_closed_over l vl
- | MatrixSum n m _ l => fully_closed_over l vl
- | VectorScalMult n _ x r => (fully_closed_over x vl) /\ (fully_closed_over r vl)
- | MatrixScalMult n m _ x r => (fully_closed_over x vl) /\ (fully_closed_over r vl)
- | MatrixVectorMult n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | MatrixVectorAdd n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | MatrixMult n m p _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | VectorPlus n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | VectorMinus n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | MatrixPlus n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | MatrixMinus n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | VectorApply n _ x s l => (fully_closed_over s ((x,DTfloat)::nil)) /\
- (fully_closed_over l vl)
- | MatrixApply n m _ x s l => (fully_closed_over s ((x,DTfloat)::nil)) /\
- (fully_closed_over l vl)
- | VLossfun n _ v1 v2 s l r => (fully_closed_over s ((v1,DTfloat)::(v2,DTfloat)::nil))
- /\ (fully_closed_over l vl)
- | MLossfun n m _ v1 v2 s l r => (fully_closed_over s ((v1,DTfloat)::(v2,DTfloat)::nil))
- /\ (fully_closed_over l vl)
- end.
-
- Definition In_compat_map (f : list var_type -> list var_type) : Prop :=
- forall (v : var_type) (vl : list var_type),
- In v vl -> In v (f vl).
-
- Definition map_tl (f : list var_type -> list var_type) (vl : list var_type) :=
- match vl with
- | a :: vl1 => a :: f vl1
- | _ => f vl
- end.
-
- Lemma In_compat_map_tl (f : list var_type -> list var_type) :
- In_compat_map f -> In_compat_map (map_tl f).
- Proof.
- unfold In_compat_map; intros.
- destruct vl.
- + now simpl.
- + simpl in *.
- destruct H0.
- * now left.
- * right; now apply H.
- Qed.
-
- Lemma fully_closed_over_map {T} (df : DefinedFunction UnitAnn T) (vl : list var_type) (f : list var_type -> list var_type) :
- In_compat_map f -> fully_closed_over df vl -> fully_closed_over df (f vl).
- Proof.
- revert f; revert vl.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros; try solve [
- trivial
- |
- apply IHdf; trivial
- |
- split; destruct H0;
- [apply IHdf1; trivial
- | apply IHdf2; trivial]
- ].
- - Case "DVector"%string.
- apply vforall_forall; intros.
- apply H; trivial.
- now rewrite vforall_forall in H1.
- - Case "DMatrix"%string.
- apply vforall_forall; intros.
- apply vforall_forall; intros.
- apply H; trivial.
- rewrite vforall_forall in H1.
- specialize (H1 i).
- now rewrite vforall_forall in H1.
- - now apply H.
- - Case "VectorApply"%string.
- split; destruct H0; trivial.
- now apply IHdf2.
- - Case "MatrixApply"%string.
- split; destruct H0; trivial.
- now apply IHdf2.
- - Case "VLossfun"%string.
- split; destruct H0; trivial.
- now apply IHdf2.
- - Case "MLossfun"%string.
- split; destruct H0; trivial.
- now apply IHdf2.
- Qed.
-
- (*
- Lemma closed_is_fully_closed {Ann} {T} (df : DefinedFunction Ann T) (vl : list var_type) :
- df_closed_over df vl <-> fully_closed_over df vl.
-*)
-
-
-(*
- Lemma df_subst_nfree {T} (e: DefinedFunction T) (v:SubVar) (e':DefinedFunction DTfloat) :
- ~ In v (df_free_variables e) ->
- df_subst e v e' = e.
- Proof.
- DefinedFunction_cases (induction e) Case; simpl; trivial; intros nin
- ; try solve [try rewrite in_app_iff in nin
- ; intuition congruence].
- - Case "DVector"%string.
- f_equal.
- apply functional_extensionality.
- intros x0.
- apply H.
- intros inn.
- apply nin.
- unfold vlconcat_map, vlconcat.
- apply concat_In.
- exists ((df_free_variables (x x0))).
- split; trivial.
- apply vector_to_list_In.
-
- - Case "DMatrix"%string.
-
- - Case "Var"%string.
- destruct (var_dec v0 v); intuition.
- Qed.
-
- Lemma df_eval_complete' {T} (σ:df_env) (f:DefinedFunction T) :
- incl (df_free_variables f) (domain σ) -> {v | df_eval σ f = Some v}.
- Proof.
- induction f; simpl; intros inc
- ; try solve [rewrite <- incl_app_iff in inc
- ; intuition
- ; destruct X as [v1 ev1]
- ; destruct X0 as [v2 ev2]
- ; rewrite ev1; rewrite ev2
- ; eauto
- | intuition
- ; destruct X as [v1 ev1]
- ; rewrite ev1
- ; eauto].
- - eauto.
- - apply in_dom_lookup_strong.
- specialize (inc v); simpl in *.
- intuition.
- Qed.
-
- (* This version has better computational properties *)
- Lemma df_eval_complete (σ:df_env) (f:DefinedFunction) :
- incl (df_free_variables f) (domain σ) -> {v | df_eval σ f = Some v}.
- Proof.
- case_eq (df_eval σ f); simpl.
- - intros r ?? ; exists r; eauto.
- - intros ? inc.
- destruct (df_eval_complete' _ _ inc); congruence.
- Defined.
-
- Lemma df_eval_none (σ:df_env) (f:DefinedFunction) :
- df_eval σ f = None ->
- {v | In v (df_free_variables f) /\ ~ In v (domain σ)}.
- Proof.
- intros.
- destruct (incl_dec (df_free_variables f) (domain σ)).
- - destruct (df_eval_complete _ _ i); congruence.
- - apply (nincl_exists) in n; trivial.
- Qed.
-
- (* Either we can evaluate df or we are missing a variable definition.
- Note that this theorem may fail to hold if we change the definition of
- division to make it partial.
- *)
- Lemma df_eval_compute (σ:df_env) (f:DefinedFunction) :
- {v | df_eval σ f = Some v} + {x | In x (df_free_variables f) /\ ~ In x (domain σ)}.
- Proof.
- case_eq (df_eval σ f); simpl.
- - eauto.
- - intros H; apply df_eval_none in H; eauto.
- Defined.
-
- Lemma df_eval_closed (f:DefinedFunction) :
- df_closed f -> {v | df_eval nil f = Some v}.
- Proof.
- intros c.
- apply (df_eval_complete nil f).
- rewrite df_closed_nil by trivial.
- simpl; reflexivity.
- Defined.
-
- Lemma df_eval_lookup_on (σ₁ σ₂:df_env) (f:DefinedFunction) :
- lookup_equiv_on (df_free_variables f) σ₁ σ₂ ->
- df_eval σ₁ f = df_eval σ₂ f.
- Proof.
- intros lookeq.
- induction f; simpl in *; trivial
- ; try solve [apply lookup_equiv_on_dom_app in lookeq; intuition
- ; rewrite H1, H2; trivial
- | rewrite IHf; trivial].
- - apply lookeq; simpl; tauto.
- Qed.
-*)
- End fv.
-
- Section apply.
-
- Fixpoint df_apply {T} (e: DefinedFunction UnitAnn T)
- (args: forall (v:var_type), DefinedFunction UnitAnn (snd v)) : DefinedFunction UnitAnn T :=
- match e with
- | Number _ x => Number tt x
- | Constant t _ x => Constant tt x
- | DVector n _ df => DVector tt (fun x => df_apply (df x) args)
- | DMatrix n m _ df => DMatrix tt (fun i j => df_apply (df i j) args)
- | Var v _ => args v
- | Plus _ l r => Plus tt (df_apply l args) (df_apply r args)
- | Times _ l r => Times tt (df_apply l args) (df_apply r args)
- | Minus _ l r => Minus tt (df_apply l args) (df_apply r args)
- | Divide _ l r => Divide tt (df_apply l args) (df_apply r args)
- | Square _ e => Square tt (df_apply e args)
- | Exp _ e => Exp tt (df_apply e args)
- | Log _ e => Log tt (df_apply e args)
- | Abs _ e => Abs tt (df_apply e args)
- | Sign _ e => Sign tt (df_apply e args)
- | PSign _ e => PSign tt (df_apply e args)
- | Max _ l r => Max tt (df_apply l args) (df_apply r args)
- | VectorElem n _ l i => VectorElem tt (df_apply l args) i
- | MatrixElem m n _ l i j => MatrixElem tt (df_apply l args) i j
- | VectorDot n _ l r => VectorDot tt (df_apply l args) (df_apply r args)
- | VectorSum n _ l => VectorSum tt (df_apply l args)
- | MatrixSum n m _ l => MatrixSum tt (df_apply l args)
- | VectorScalMult n _ x r => VectorScalMult tt (df_apply x args) (df_apply r args)
- | MatrixScalMult n m _ x r => MatrixScalMult tt (df_apply x args) (df_apply r args)
- | MatrixVectorMult n m _ l r => MatrixVectorMult tt (df_apply l args) (df_apply r args)
- | MatrixVectorAdd n m _ l r => MatrixVectorAdd tt (df_apply l args) (df_apply r args)
- | MatrixMult n m p _ l r => MatrixMult tt (df_apply l args) (df_apply r args)
- | VectorPlus n _ l r => VectorPlus tt (df_apply l args) (df_apply r args)
- | VectorMinus n _ l r => VectorMinus tt (df_apply l args) (df_apply r args)
- | MatrixPlus n m _ l r => MatrixPlus tt (df_apply l args) (df_apply r args)
- | MatrixMinus n m _ l r => MatrixMinus tt (df_apply l args) (df_apply r args)
- | VectorApply n _ x s l => VectorApply tt x (df_apply s args) (df_apply l args)
- | MatrixApply n m _ x s l => MatrixApply tt x (df_apply s args) (df_apply l args)
- | VLossfun n _ v1 v2 s l r => VLossfun tt v1 v2 (df_apply s args) (df_apply l args) r
- | MLossfun n m _ v1 v2 s l r => MLossfun tt v1 v2 (df_apply s args) (df_apply l args) r
- end.
-
- End apply.
-
-End DefinedFunctions.
-
-Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "DVector"%string
- | Case_aux c "DMatrix"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string
- | Case_aux c "VectorDot"%string
- | Case_aux c "VectorSum"%string
- | Case_aux c "MatrixSum"%string
- | Case_aux c "VectorElem"%string
- | Case_aux c "MatrixElem"%string
- | Case_aux c "MatrixVectorMult"%string
- | Case_aux c "MatrixVectorAdd"%string
- | Case_aux c "MatrixMult"%string
- | Case_aux c "VectorPlus"%string
- | Case_aux c "VectorMinus"%string
- | Case_aux c "MatrixPlus"%string
- | Case_aux c "MatrixMinus"%string
- | Case_aux c "VectorScalMult"%string
- | Case_aux c "MatrixScalMult"%string
- | Case_aux c "VectorApply"%string
- | Case_aux c "MatrixApply"%string
- | Case_aux c "VLossfun"%string
- | Case_aux c "MLossfun"%string].
-
-Ltac refl_simpler :=
- repeat
- match goal with
- | [H: @eq var_type _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H)
- | [H: @equiv var_type _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H)
- | [H: @eq definition_function_types _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H)
- | [H: @equiv definition_function_types _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H)
- end.
-
-Section real_pfs.
-
- Local Existing Instance floatish_R.
- Import Reals.
- Import List.
-
- Lemma MaxDerivedMax_eq (a b : DefinedFunction UnitAnn DTfloat) :
- forall σ, df_eval σ (Max tt a b) = df_eval σ (MaxDerived a b).
- Proof.
- simpl; intros σ.
- destruct (df_eval σ a); destruct (df_eval σ b); trivial.
- f_equal.
- autorewrite with Rarith in *.
- destruct (Rle_dec d d0).
- - rewrite Rmax_right by trivial.
- rewrite Rabs_pos_eq by lra.
- lra.
- - rewrite Rmax_left by lra.
- rewrite Rabs_minus_sym.
- rewrite Rabs_pos_eq by lra.
- lra.
- Qed.
-
-(* Lemma coerce_dec_id {A} (dec:forall x y:A, {x=y}+{x<>y}) (x:A) (pf:x=x) : coerce pf x = x.
- Proof.
- unfold coerce.
- replace pf with (eq_refl A); trivial.
- apply UIP_dec.
- apply dec.
- generalize (@UIP_dec A dec pf).
- Lemma var_type_UIP_refl {x:var_type} (e:x=x) : e = eq_refl x.
- Proof.
- apply (UIP_dec vart_dec x x pf).
- Qed.
-
- unfold coerce.
- destruct pf.
- destruct pf.
- exact a.
- Defined.
-*)
-
-Tactic Notation "DefinedFunction_scalar_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string].
-
-
- Lemma backpropeq_gen (x : SubVar) (env gradenv : df_env) (dfexpr : DefinedFunction UnitAnn DTfloat) (grad : float) :
- let xvar := (x, DTfloat) in
- is_scalar_function dfexpr ->
- vartlookup gradenv (x,DTfloat) <> None ->
- match df_eval_deriv env dfexpr xvar,
- backprop_lookup (Some gradenv) xvar,
- backprop_lookup (df_eval_backprop_deriv env dfexpr gradenv grad) xvar
- with
- | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1
- | None, _, None => True
- | _, _, _ => False
- end.
- Proof.
- simpl.
- intros is_scalar.
- generalize is_scalar.
- revert grad gradenv.
- pattern dfexpr.
- revert dfexpr is_scalar.
- DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl.
- - Case "Number"%string.
- intros _ _ grad gradenv _ xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Constant"%string.
- intros _ _ grad gradenv _ xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Var"%string.
- intros sv _ grad gradenv _ xinn.
- case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto].
- destruct (var_dec x sv); simpl.
- + subst.
- rewrite H; simpl.
- rewrite lookup_update.
- destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence].
- unfold addvar; simpl.
- rewrite H.
- lra.
- + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ].
- case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros.
- * rewrite lookup_update_neq by congruence.
- rewrite H.
- lra.
- * rewrite H.
- lra.
- - Case "Plus"%string.
- intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn.
- case_eq (df_eval_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv isc1 xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr grad ge' isc2).
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_backprop_deriv env r ge' grad) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr; trivial.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv isc1 xinn).
- case_eq (df_eval_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Minus"%string.
- intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn.
- case_eq (df_eval_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv isc1 xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- grad)%R ge' isc2).
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr; trivial.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv isc1 xinn).
- case_eq (df_eval_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Times"%string.
- intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn.
- case_eq (df_eval env l);
- [ intros le eqle | intros eqle]; simpl; trivial.
- case_eq (df_eval_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval env r);
- [ intros re eqre | intros eqre]
- ; simpl; trivial.
- case_eq (df_eval_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (re * grad)%R gradenv isc1 xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env l gradenv (re * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (le * grad)%R ge' isc2).
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (le * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_backprop_deriv env l gradenv (re * grad)%R ); simpl; trivial; intros.
- apply IHr; trivial.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + case_eq (df_eval env r);
- [ intros re eqre | intros eqre]
- ; simpl; trivial.
- specialize (IHl (re * grad)%R gradenv isc1 xinn).
- case_eq (df_eval_backprop_deriv env l gradenv (re * grad)%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Divide"%string.
- intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn.
- case_eq (df_eval env l);
- [ intros le eqle | intros eqle]; simpl; trivial.
- case_eq (df_eval_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval env r);
- [ intros re eqre | intros eqre]
- ; simpl; trivial.
- case_eq (df_eval_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (grad / re)%R gradenv isc1 xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env l gradenv (grad / re)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- le / (re * re) * grad)%R ge' isc2).
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (- le / (re * re) * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_backprop_deriv env l gradenv (grad / re)%R ); simpl; trivial; intros.
- apply IHr; trivial.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + case_eq (df_eval env r);
- [ intros re eqre | intros eqre]
- ; simpl; trivial.
- specialize (IHl (grad / re)%R gradenv isc1 xinn).
- case_eq (df_eval_backprop_deriv env l gradenv (grad / re)%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Square"%string.
- intros _ e IHe grad gradenv isc xinn.
- case_eq (df_eval env e);
- [ intros le eqee | intros eqee]; simpl; trivial.
-
- specialize (IHe (2 * le * grad)%R gradenv isc xinn).
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env e gradenv (2 * le * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Exp"%string.
- intros _ e IHe grad gradenv isc xinn.
- case_eq (df_eval env e);
- [ intros le eqee | intros eqee]; simpl; trivial.
-
- specialize (IHe (grad * exp le)%R gradenv isc xinn).
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env e gradenv (grad * exp le)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Log"%string.
- intros _ e IHe grad gradenv isc xinn.
- case_eq (df_eval env e);
- [ intros le eqee | intros eqee]; simpl; trivial.
-
- specialize (IHe (grad / le)%R gradenv isc xinn).
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env e gradenv (grad / le)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Abs"%string.
- intros _ e IHe grad gradenv isc xinn.
- case_eq (df_eval env e);
- [ intros le eqee | intros eqee]; simpl; trivial.
-
- specialize (IHe (grad * (sign le))%R gradenv isc xinn).
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env e gradenv (grad * (sign le))%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Sign"%string.
- intros _ e IHe grad gradenv isc xinn.
- specialize (IHe 0%R gradenv isc xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- lra.
- - Case "PSign"%string.
- intros _ e IHe grad gradenv isc xinn.
- specialize (IHe 0%R gradenv isc xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- lra.
- - Case "Max"%string.
- intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn.
- specialize (IHl grad gradenv isc1 xinn).
- specialize (IHr grad gradenv isc2 xinn).
-
- case_eq (df_eval env l); simpl; trivial
- ; intros eld eqeld.
- case_eq (df_eval env r ); simpl; intros; trivial.
- destruct (Rle_dec eld d); simpl.
- + destruct (df_eval_deriv env r (x, DTfloat)); simpl; trivial.
- + destruct (df_eval_deriv env l (x, DTfloat)); simpl; trivial.
- Qed.
-
- (*
-
- Lemma tree_backpropeq_gen (x : SubVar) (env gradenv : df_env)
- (dfexpr : DefinedFunction EvalAnn DTfloat) (grad : float) :
- let xvar := (x, DTfloat) in
- is_scalar_function dfexpr ->
- vartlookup gradenv (x,DTfloat) <> None ->
- match df_eval_tree_deriv env dfexpr xvar,
- backprop_lookup (Some gradenv) xvar,
- backprop_lookup (df_eval_tree_backprop_deriv env dfexpr gradenv grad) xvar
- with
- | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1
- | None, _, None => True
- | _, _, _ => False
- end.
- Proof.
- simpl.
- intros is_scalar.
- revert grad gradenv.
- pattern dfexpr.
- revert dfexpr is_scalar.
- DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl.
-
- - Case "Number"%string.
- intros _ _ grad gradenv xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Constant"%string.
- intros _ _ grad gradenv xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Var"%string.
- intros sv _ grad gradenv xinn.
- case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto].
- destruct (var_dec x sv); simpl.
- + subst.
- rewrite H; simpl.
- rewrite lookup_update.
- destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence].
- unfold addvar; simpl.
- rewrite H.
- lra.
- + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ].
- case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros.
- * rewrite lookup_update_neq by congruence.
- rewrite H.
- lra.
- * rewrite H.
- lra.
- - Case "Plus"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr grad ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' grad) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Minus"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Times"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (get_annotation r * grad)%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (get_annotation l * grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (get_annotation l * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl (get_annotation r * grad)%R gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Divide"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (grad / get_annotation r)%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl (grad / get_annotation r)%R gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r )%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Square"%string.
- intros _ e IHe grad gradenv xinn.
-
- specialize (IHe (2 * (get_annotation e) * grad)%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (2 * (get_annotation e) * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Exp"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad * exp (get_annotation e))%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * exp (get_annotation e))%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
-
- - Case "Log"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad / get_annotation e)%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad / get_annotation e)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
-
- - Case "Abs"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad * (sign (get_annotation e)))%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * (sign (get_annotation e)))%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Sign"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe 0%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- replace (de * 0)%R with (0)%R in IHe by lra.
- replace (0 * grad)%R with (0)%R by lra.
- apply IHe.
- - Case "PSign"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe 0%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- replace (de * 0)%R with (0)%R in IHe by lra.
- replace (0 * grad)%R with (0)%R by lra.
- apply IHe.
- - Case "Max"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- specialize (IHl grad gradenv xinn).
- specialize (IHr grad gradenv xinn).
- destruct (Rle_dec (get_annotation l) (get_annotation r)); simpl.
- destruct (df_eval_tree_deriv env r (x, DTfloat)); simpl; trivial.
- destruct (df_eval_tree_deriv env l (x, DTfloat)); simpl; trivial.
- Qed.
- *)
-
- Lemma eval_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T) :
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl -> df_eval σ df <> None.
- Proof.
- revert σ.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros;
- try solve
- [congruence
- |
- destruct H; simpl in *;
- specialize (IHdf1 σ); specialize (IHdf2 σ);
- match_option; [|tauto];
- cut_to IHdf2;
- [ match_option; tauto | easy]
- |
- specialize (IHdf σ); simpl in IHdf;
- cut_to IHdf;
- [ match_option; tauto | easy]
- ].
- - Case "DVector"%string.
- apply vectoro_to_ovector_not_none; intro.
- specialize (H i σ); simpl in H; apply H.
- rewrite vforall_forall in H0.
- now specialize (H0 i).
- - Case "DMatrix"%string.
- unfold matrixo_to_omatrix.
- apply vectoro_to_ovector_not_none; intro.
- apply vectoro_to_ovector_not_none; intro.
- specialize (H i i0 σ); simpl in H; apply H.
- rewrite vforall_forall in H0; specialize (H0 i).
- rewrite vforall_forall in H0; now specialize (H0 i0).
- - Case "Var"%string.
- induction σ.
- + simpl in H; tauto.
- + simpl in *.
- match_case; intros.
- destruct H.
- * congruence.
- * now apply IHσ.
- - Case "VectorApply"%string.
- destruct H; simpl in *.
- specialize (IHdf2 σ).
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- apply vectoro_to_ovector_not_none; intro.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil)).
- now apply IHdf1.
- - Case "MatrixApply"%string.
- destruct H; simpl in *.
- specialize (IHdf2 σ).
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- unfold matrixo_to_omatrix.
- apply vectoro_to_ovector_not_none; intro.
- apply vectoro_to_ovector_not_none; intro.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil)).
- now apply IHdf1.
- - Case "VLossfun"%string.
- destruct H; simpl in *.
- specialize (IHdf2 σ).
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- match_option.
- apply vectoro_to_ovector_not_none in eqq0.
- + tauto.
- + intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil)).
- now apply IHdf1.
- - Case "MLossfun"%string.
- destruct H; simpl in *.
- specialize (IHdf2 σ).
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- unfold matrixo_to_omatrix.
- match_option.
- apply vectoro_to_ovector_not_none in eqq0.
- + tauto.
- + intros.
- apply vectoro_to_ovector_not_none; intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil)).
- now apply IHdf1.
- Qed.
-
- Lemma eval_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T) :
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- {d:definition_function_types_interp T | df_eval σ df = Some d}.
- Proof.
- intros.
- case_eq (df_eval σ df); intros.
- - now exists d.
- - generalize (eval_fully_closed_not_none σ df).
- intros; simpl in *.
- cut_to H1; tauto.
- Qed.
-
- Lemma closed_over_cons {T} (df:DefinedFunction UnitAnn T) (v:var_type) (vl : list var_type):
- df_closed_over df vl -> df_closed_over df (v::vl).
- Proof.
- unfold df_closed_over.
- intros.
- apply incl_tl.
- apply H.
- Qed.
-
- Lemma fully_closed_over_cons {T} (df:DefinedFunction UnitAnn T) (v:var_type)
- (vl : list var_type):
- fully_closed_over df vl -> fully_closed_over df (v::vl).
- Proof.
- intros.
- apply (fully_closed_over_map df vl (fun vl1 => cons v vl1)); trivial.
- unfold In_compat_map.
- intros.
- now apply in_cons.
- Qed.
-
- Lemma fully_closed_over_exchange_vars {T} (df:DefinedFunction UnitAnn T) (v1 v:var_type)
- (vl : list var_type):
- fully_closed_over df (v1 :: v :: vl) -> fully_closed_over df (v :: v1 :: vl).
- Proof.
- intros.
- apply (fully_closed_over_map df (v1 :: v :: vl)
- (fun vl1 => match vl1 with
- | a :: b :: vl2 => b :: a :: vl2
- | _ => vl1
- end )); trivial.
- unfold In_compat_map.
- intros.
- destruct vl0; trivial.
- destruct vl0; trivial.
- unfold In.
- unfold In in H0.
- tauto.
- Qed.
-
- Lemma fully_closed_over_singleton {T} (df:DefinedFunction UnitAnn T) (v:var_type)
- (vl : list var_type):
- fully_closed_over df (v::nil) -> fully_closed_over df (v::vl).
- Proof.
- intros.
- induction vl; trivial.
- apply fully_closed_over_exchange_vars.
- now apply fully_closed_over_cons.
- Qed.
-
- Lemma fully_closed_over_exchange_2vars {T} (df:DefinedFunction UnitAnn T)
- (v1 v2 v:var_type) (vl : list var_type):
- fully_closed_over df (v1 :: v2 :: v:: vl) -> fully_closed_over df (v :: v1 :: v2 :: vl).
- Proof.
- intros.
- apply (fully_closed_over_map df (v1 :: v2 :: v :: vl)
- (fun vl1 => match vl1 with
- | a :: b :: c :: vl2 => c :: a :: b :: vl2
- | _ => vl1
- end )); trivial.
- unfold In_compat_map.
- intros.
- destruct vl0; trivial.
- destruct vl0; trivial.
- destruct vl0; trivial.
- unfold In.
- unfold In in H0.
- tauto.
- Qed.
-
- Lemma fully_closed_over_pair {T} (df:DefinedFunction UnitAnn T) (v1 v2:var_type)
- (vl : list var_type):
- fully_closed_over df (v1::v2::nil) -> fully_closed_over df (v1::v2::vl).
- Proof.
- intros.
- induction vl; trivial.
- apply fully_closed_over_exchange_2vars.
- apply fully_closed_over_exchange_2vars.
- now apply fully_closed_over_cons.
- Qed.
-
- Lemma fully_closed_subst {T} (vl:list var_type) (df:DefinedFunction UnitAnn T) (v:var_type)
- (e':DefinedFunction UnitAnn (snd v)):
- fully_closed_over df (v::vl) ->
- fully_closed_over e' vl ->
- fully_closed_over (df_subst df v e') vl.
- Proof.
- revert vl.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros; try solve
- [easy
- |
- apply IHdf; [apply H | apply H0]
- |
- split; destruct H; simpl in *;
- [ apply IHdf1; [apply H | apply H0]
- | apply IHdf2; [apply H1 | apply H0]]
- ].
- - Case "DVector"%string.
- apply vforall_forall; intros; simpl in *.
- specialize (H i); apply H.
- rewrite vforall_forall in H0.
- specialize (H0 i).
- apply H0.
- apply H1.
- - Case "DMatrix"%string.
- apply vforall_forall; intros.
- apply vforall_forall; intros; simpl in *.
- specialize (H i i0); apply H.
- rewrite vforall_forall in H0; specialize (H0 i).
- rewrite vforall_forall in H0; specialize (H0 i0).
- apply H0.
- apply H1.
- - Case "Var"%string.
- unfold substvar.
- destruct H.
- + subst.
- unfold substvar.
- match_destr; [ | congruence].
- refl_simpler.
- simpl; trivial.
- + destruct v; destruct v0.
- simpl in *.
- match_destr.
- red in e; subst.
- simpl; trivial.
- - Case "VectorApply"%string.
- destruct H; split; trivial.
- + apply IHdf2.
- * apply H1.
- * apply H0.
- - Case "MatrixApply"%string.
- destruct H; split; trivial.
- + apply IHdf2.
- * apply H1.
- * apply H0.
- - Case "VLossfun"%string.
- destruct H; split; trivial.
- + apply IHdf2.
- * apply H1.
- * apply H0.
- - Case "MLossfun"%string.
- destruct H; split; trivial.
- + apply IHdf2.
- * apply H1.
- * apply H0.
- Qed.
-
- Lemma fully_closed_deriv {T} (df:DefinedFunction UnitAnn T) (s:SubVar)
- (vl : list var_type):
- fully_closed_over df vl ->
- fully_closed_over (df_deriv df (s, DTfloat)) vl.
- Proof.
- revert s; revert vl.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros; try solve
- [easy
- |
- apply IHdf,H
- |
- split; try easy; now apply IHdf
- |
- destruct H; repeat split; try easy;
- [apply IHdf2, H0 | apply IHdf1, H]
- |
- destruct H; repeat split; try easy;
- [ apply IHdf1, H | apply IHdf2, H0]
- ].
- - Case "DVector"%string.
- apply vforall_forall; intros.
- apply H.
- rewrite vforall_forall in H0.
- now apply H0.
- - Case "DMatrix"%string.
- apply vforall_forall; intros.
- apply vforall_forall; intros.
- apply H.
- rewrite vforall_forall in H0.
- specialize (H0 i).
- rewrite vforall_forall in H0.
- now specialize (H0 i0).
- - Case "Max"%string.
- destruct H; repeat split; try easy.
- apply IHdf2, H0.
- apply IHdf1, H.
- apply IHdf2, H0.
- apply IHdf1, H.
- - Case "VectorApply"%string.
- apply vforall_forall; intros; simpl in *.
- split; destruct H.
- apply IHdf2, H0.
- apply fully_closed_subst.
- + apply IHdf1.
- now apply fully_closed_over_singleton.
- + simpl; apply H0.
- - Case "MatrixApply"%string.
- apply vforall_forall; intros; simpl in *.
- apply vforall_forall; intros; simpl in *.
- split; destruct H.
- apply IHdf2, H0.
- apply fully_closed_subst.
- + apply IHdf1.
- now apply fully_closed_over_singleton.
- + simpl; apply H0.
- - Case "VLossfun"%string.
- intros; simpl in *.
- split; destruct H.
- apply IHdf2, H0.
- apply vforall_forall; intros; simpl in *.
- apply fully_closed_subst.
- apply fully_closed_subst.
- + apply IHdf1.
- now apply fully_closed_over_pair.
- + simpl; apply fully_closed_over_cons; apply H0.
- + now simpl.
- - Case "MLossfun"%string.
- intros; simpl in *.
- apply vforall_forall; intros; simpl in *.
- apply vforall_forall; intros; simpl in *.
- destruct H; split; trivial.
- split.
- apply IHdf2, H0.
- apply fully_closed_subst.
- apply fully_closed_subst.
- + apply IHdf1.
- now apply fully_closed_over_pair.
- + simpl; apply fully_closed_over_cons; apply H0.
- + now simpl.
- Qed.
-
- Lemma list_env_iter_total_fun {A} (f : A -> df_env -> option df_env) (env : df_env) (l : list A) :
- (forall (a:A) (env0: df_env), (f a env0) <> None) ->
- list_env_iter f (Some env) l <> None.
- Proof.
- intros.
- generalize env.
- induction l; [simpl; congruence|].
- simpl; intros.
- specialize (H a env0).
- case_eq (f a env0).
- - intros; apply (IHl d).
- - tauto.
- Qed.
-
- Lemma backprop_deriv_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T)
- (grad_env:df_env) (grad: definition_function_types_interp T):
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl -> df_eval_backprop_deriv σ df grad_env grad <> None.
- Proof.
- revert grad_env.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros;
- try solve [congruence
- |
- specialize (IHdf1 grad grad_env); simpl in IHdf1; destruct H;
- match_option; [|tauto];
- specialize (IHdf2 grad d); simpl in IHdf2;
- now apply IHdf2
- ].
- - Case "DVector"%string.
- unfold two_vector_env_iter_alt.
- rewrite vforall_forall in H0.
- apply (list_env_iter_total_fun
- (fun i env => df_eval_backprop_deriv σ (x i) env (grad i))
- grad_env (bounded_seq0 n)).
- intros.
- apply (H a (grad a) env0).
- apply (H0 a).
- - Case "DMatrix"%string.
- unfold two_matrix_env_iter_alt.
- rewrite vforall_forall in H0.
- apply (list_env_iter_total_fun
- (fun i env =>
- list_env_iter
- (fun j env0 =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (Some env) (bounded_seq0 m))
- grad_env (bounded_seq0 n)).
- intros.
- apply (list_env_iter_total_fun
- (fun j env => df_eval_backprop_deriv σ (x a j) env (grad a j))
- env0 (bounded_seq0 m)).
- intros.
- apply (H a a0 (grad a a0) env1).
- specialize (H0 a).
- rewrite vforall_forall in H0.
- apply (H0 a0).
- - Case "Var"%string.
- match_destr.
- - Case "Minus"%string.
- specialize (IHdf1 grad grad_env); simpl in IHdf1; destruct H;
- match_option; [|tauto];
- specialize (IHdf2 (-grad)%R d); simpl in IHdf2;
- now apply IHdf2.
- - Case "Times"%string.
- destruct H.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2.
- + match_option; [|tauto].
- specialize (IHdf1 (d0 * grad)%R grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (d * grad)%R d1); simpl in IHdf2.
- now apply IHdf2.
- - Case "Divide"%string.
- destruct H.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2.
- + match_option; [|tauto].
- specialize (IHdf1 (grad / d0)%R grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (-d / (d0 * d0) * grad)%R d1); simpl in IHdf2.
- now apply IHdf2.
- - Case "Square"%string.
- generalize (eval_fully_closed_not_none σ df); intros; simpl in H0.
- match_option; [|tauto].
- specialize (IHdf (2 * d * grad)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "Exp"%string.
- generalize (eval_fully_closed_not_none σ df); intros; simpl in H0.
- match_option; [|tauto].
- specialize (IHdf (grad * exp d)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "Log"%string.
- generalize (eval_fully_closed_not_none σ df); intros; simpl in H0.
- match_option; [|tauto].
- specialize (IHdf (grad / d)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "Abs"%string.
- generalize (eval_fully_closed_not_none σ df); intros; simpl in H0.
- match_option; [|tauto].
- specialize (IHdf (grad * sign d)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "Sign"%string.
- specialize (IHdf (0)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "PSign"%string.
- specialize (IHdf (0)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "Max"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- match_destr.
- specialize (IHdf2 grad grad_env).
- now apply IHdf2.
- specialize (IHdf1 grad grad_env).
- now apply IHdf1.
- - Case "VectorDot"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- specialize (IHdf1 (vmap (fun rv : R => (rv * grad)%R) d0) grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (vmap (fun lv : R => (lv * grad)%R) d) d1).
- now apply IHdf2.
- - Case "VectorSum"%string.
- specialize (IHdf (ConstVector n grad) grad_env).
- now apply IHdf.
- - Case "MatrixSum"%string.
- specialize (IHdf (ConstMatrix m n grad) grad_env).
- now apply IHdf.
- - Case "VectorElem"%string.
- specialize (IHdf (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i) then grad else 0%R) grad_env).
- now apply IHdf.
- - Case "MatrixElem"%string.
- specialize (IHdf (fun (k1 : {n' : nat | n' < m})
- (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i) then if
- equiv_dec (` k2) (` j) then grad else 0%R else 0%R)
- grad_env).
- now apply IHdf.
- - Case "MatrixVectorMult"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- specialize (IHdf1 (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => (grad i * d0 j)%R)
- grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (matrix_vector_mult
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d j i)
- grad)
- d1).
- now apply IHdf2.
- - Case "MatrixVectorAdd"%string.
- specialize (IHdf1 grad grad_env); simpl in IHdf1.
- match_option; [|tauto].
- match_option.
- generalize (list_env_iter_total_fun
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (@transpose R m n grad i))
- d (bounded_seq0 n)); intros.
- cut_to H0; [congruence|].
- intros; destruct H.
- now apply IHdf2.
- - Case "MatrixMult"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- specialize (IHdf1 (matrix_mult grad (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < p}) => d0 j i))
- grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d j i) grad)
- d1).
- now apply IHdf2.
- - Case "VectorMinus"%string.
- specialize (IHdf1 grad grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad i)%R) d).
- now apply IHdf2.
- - Case "MatrixMinus"%string.
- specialize (IHdf1 grad grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (- grad i j)%R)
- d).
- now apply IHdf2.
- - Case "VectorScalMult"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- specialize (IHdf1 (vsum (fun j : {n' : nat | n' < n} => (d0 j * grad j)%R))
- grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (fun j : {n' : nat | n' < n} => (d * grad j)%R) d1).
- now apply IHdf2.
- - Case "MatrixScalMult"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- specialize (IHdf1 (msum
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (d0 i j * grad i j)%R))
- grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad i j * d)%R)
- d1).
- now apply IHdf2.
- - Case "VectorApply"%string.
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0.
- match_option; [|tauto].
- match_option.
- + specialize (IHdf2 v0 grad_env).
- now apply IHdf2.
- + apply vectoro_to_ovector_exists_None in eqq0.
- destruct eqq0.
- rewrite vmap_nth in e; simpl in e.
- destruct H.
- match_option_in e.
- generalize (fully_closed_deriv df1 v ((v,DTfloat):: nil)); intros.
- cut_to H2; trivial.
- generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (d x) :: nil)
- (df_deriv df1 (v, DTfloat))); intros.
- simpl in H3; cut_to H3; tauto.
- - Case "MatrixApply"%string.
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0.
- match_option; [|tauto].
- match_option.
- + specialize (IHdf2 m0 grad_env).
- now apply IHdf2.
- + unfold matrixo_to_omatrix in eqq0.
- apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0.
- apply vectoro_to_ovector_exists_None in e; destruct e.
- unfold mmap in e.
- rewrite vmap_nth in e; simpl in e.
- rewrite vmap_nth in e; simpl in e.
- destruct H.
- unfold matrix_zip in e.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- generalize (fully_closed_deriv df1 v ((v,DTfloat):: nil)).
- intros; cut_to H2; trivial.
- generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (d x x0) :: nil)
- (df_deriv df1 (v, DTfloat))); intros.
- simpl in H3; cut_to H3; tauto.
- - Case "VLossfun"%string.
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0.
- match_option; [|tauto].
- match_option.
- + specialize (IHdf2 v grad_env).
- now apply IHdf2.
- + apply vectoro_to_ovector_exists_None in eqq0.
- destruct eqq0.
- rewrite vmap_nth in e; simpl in e.
- destruct H.
- match_option_in e.
- generalize (fully_closed_deriv df1 v1 ((v1,DTfloat)::(v2,DTfloat)::nil)).
- intros; cut_to H2; trivial.
- generalize (eval_fully_closed_not_none (mk_env_entry (v1, DTfloat) (d x) ::
- mk_env_entry (v2, DTfloat) (r x) :: nil)
- (df_deriv df1 (v1, DTfloat))); intros.
- simpl in H3; cut_to H3; tauto.
- - Case "MLossfun"%string.
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0.
- match_option; [|tauto].
- match_option.
- + specialize (IHdf2 m0 grad_env).
- now apply IHdf2.
- + unfold matrixo_to_omatrix in eqq0.
- apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0.
- apply vectoro_to_ovector_exists_None in e; destruct e.
- unfold mmap in e.
- rewrite vmap_nth in e; simpl in e.
- rewrite vmap_nth in e; simpl in e.
- destruct H.
- unfold matrix_zip in e.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- generalize (fully_closed_deriv df1 v1 ((v1,DTfloat)::(v2,DTfloat)::nil)).
- intros; cut_to H2; trivial.
- generalize (eval_fully_closed_not_none (mk_env_entry (v1, DTfloat) (d x x0) ::
- mk_env_entry (v2, DTfloat) (r x x0) :: nil)
- (df_deriv df1 (v1, DTfloat))); intros.
- simpl in H3; cut_to H3; tauto.
- Qed.
-
- Lemma backprop_deriv_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T)
- (grad_env:df_env) (grad: definition_function_types_interp T):
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- {d:df_env | df_eval_backprop_deriv σ df grad_env grad = Some d}.
- Proof.
- case_eq (df_eval_backprop_deriv σ df grad_env grad); intros.
- - now exists d.
- - generalize (backprop_deriv_fully_closed_not_none σ df grad_env grad).
- intros; simpl in *.
- cut_to H1; tauto.
- Qed.
-
- Lemma eval_deriv_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T)
- (v:var_type):
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl -> df_eval_deriv σ df v <> None.
- Proof.
- revert σ v.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; intros; simpl in *;
- try solve [
- congruence
- |
- destruct H
- ; specialize (IHdf1 σ v); specialize (IHdf2 σ v)
- ;cut_to IHdf1; trivial
- ;match_option; [|tauto]
- ;cut_to IHdf2; trivial
- ;match_option; tauto
- |
- destruct H;
- specialize (IHdf1 σ v); specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1;
- cut_to H1; trivial;
- match_option; [|tauto];
- cut_to IHdf1; trivial;
- match_option; [|tauto];
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2;
- match_option; [|tauto];
- cut_to IHdf2; trivial;
- match_option; tauto
- |
- generalize (eval_fully_closed_not_none σ df); intros;
- specialize (IHdf σ v);
- simpl in H0; cut_to H0; trivial;
- match_option; [|tauto];
- cut_to IHdf; trivial;
- match_option; tauto
- |
- specialize (IHdf σ v);
- generalize (eval_fully_closed_not_none σ df); intros;
- simpl in H0; cut_to H0; trivial;
- match_option; tauto
- ].
- - Case "DVector"%string.
- apply vectoro_to_ovector_not_none; intros; apply H.
- rewrite vforall_forall in H0; apply H0.
- - Case "DMatrix"%string.
- apply vectoro_to_ovector_not_none; intros.
- apply vectoro_to_ovector_not_none; intros; apply H.
- rewrite vforall_forall in H0; specialize (H0 i).
- rewrite vforall_forall in H0; apply H0.
- - Case "Max"%string.
- destruct H;
- specialize (IHdf1 σ v); specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1;
- cut_to H1; trivial;
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2;
- match_option; [|tauto].
- case_eq ( Rle_dec d d0 ); intros.
- cut_to IHdf2; trivial.
- cut_to IHdf1; trivial.
- - Case "VectorApply"%string.
- destruct H.
- specialize (IHdf2 σ v0);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- apply vectoro_to_ovector_not_none; intros.
- match_option.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil) (v, DTfloat)).
- cut_to IHdf1; trivial.
- tauto.
- - Case "MatrixApply"%string.
- destruct H.
- specialize (IHdf2 σ v0);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- apply vectoro_to_ovector_not_none; intros.
- apply vectoro_to_ovector_not_none; intros.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil) (v, DTfloat)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- - Case "VLossfun"%string.
- destruct H.
- specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- match_option.
- apply vectoro_to_ovector_not_none in eqq1.
- tauto.
- intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil) (v1, DTfloat)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- - Case "MLossfun"%string.
- destruct H.
- specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- match_option.
- unfold matrixo_to_omatrix in eqq1.
- apply vectoro_to_ovector_not_none in eqq1.
- tauto.
- intros.
- apply vectoro_to_ovector_not_none; intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil)(v1, DTfloat)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- Qed.
-
- Lemma eval_deriv_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T)
- (v:var_type):
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- {d:definition_function_types_interp T | df_eval_deriv σ df v = Some d}.
- Proof.
- case_eq (df_eval_deriv σ df v); intros.
- - now exists d.
- - generalize (eval_deriv_fully_closed_not_none σ df v).
- intros; simpl in *.
- cut_to H1; tauto.
- Qed.
-
- Lemma eval_deriv_genvar_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T)
- (v:df_env):
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl -> df_eval_deriv_genvar σ df v <> None.
- Proof.
- revert σ v.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; intros; simpl in *;
- try solve [
- congruence
- |
- destruct H
- ; specialize (IHdf1 σ v); specialize (IHdf2 σ v)
- ;cut_to IHdf1; trivial
- ;match_option; [|tauto]
- ;cut_to IHdf2; trivial
- ;match_option; tauto
- |
- destruct H;
- specialize (IHdf1 σ v); specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1;
- cut_to H1; trivial;
- match_option; [|tauto];
- cut_to IHdf1; trivial;
- match_option; [|tauto];
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2;
- match_option; [|tauto];
- cut_to IHdf2; trivial;
- match_option; tauto
- |
- generalize (eval_fully_closed_not_none σ df); intros;
- specialize (IHdf σ v);
- simpl in H0; cut_to H0; trivial;
- match_option; [|tauto];
- cut_to IHdf; trivial;
- match_option; tauto
- |
- specialize (IHdf σ v);
- generalize (eval_fully_closed_not_none σ df); intros;
- simpl in H0; cut_to H0; trivial;
- match_option; tauto
- ].
- - Case "DVector"%string.
- apply vectoro_to_ovector_not_none; intros; apply H.
- rewrite vforall_forall in H0; apply H0.
- - Case "DMatrix"%string.
- apply vectoro_to_ovector_not_none; intros.
- apply vectoro_to_ovector_not_none; intros; apply H.
- rewrite vforall_forall in H0; specialize (H0 i).
- rewrite vforall_forall in H0; apply H0.
- - Case "Max"%string.
- destruct H;
- specialize (IHdf1 σ v); specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1;
- cut_to H1; trivial;
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2;
- match_option; [|tauto].
- case_eq ( Rle_dec d d0 ); intros.
- cut_to IHdf2; trivial.
- cut_to IHdf1; trivial.
- - Case "VectorApply"%string.
- destruct H.
- specialize (IHdf2 σ v0);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- apply vectoro_to_ovector_not_none; intros.
- match_option.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil) (mk_genvar_env v)).
- cut_to IHdf1; trivial.
- tauto.
- - Case "MatrixApply"%string.
- destruct H.
- specialize (IHdf2 σ v0);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- apply vectoro_to_ovector_not_none; intros.
- apply vectoro_to_ovector_not_none; intros.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil) (mk_genvar_env v)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- - Case "VLossfun"%string.
- destruct H.
- specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- match_option.
- apply vectoro_to_ovector_not_none in eqq1.
- tauto.
- intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil) (mk_genvar_env v1)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- - Case "MLossfun"%string.
- destruct H.
- specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- match_option.
- unfold matrixo_to_omatrix in eqq1.
- apply vectoro_to_ovector_not_none in eqq1.
- tauto.
- intros.
- apply vectoro_to_ovector_not_none; intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil) (mk_genvar_env v1)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- Qed.
-
- Definition scalarMult (T : definition_function_types) (c : float) :=
- match T return
- definition_function_types_interp T -> definition_function_types_interp T with
- | DTfloat => fun f => (c * f)%R
- | DTVector n => fun f => fun i => (c * f i)%R
- | DTMatrix n m => fun f => fun i j => (c * f i j)%R
- end.
-
- Definition dfti_gen_plus {T : definition_function_types} :=
- match T return
- definition_function_types_interp T -> definition_function_types_interp T ->
- definition_function_types_interp T with
- | DTfloat => fun f g => (f + g)%R
- | DTVector n => fun f g => fun i => (f i + g i)%R
- | DTMatrix n m => fun f g => fun i j => (f i j + g i j)%R
- end.
-
- Lemma subvar_addvar_scalar_neq (env : df_env) (oval : float) (s : SubVar) (v: var_type) (grad : definition_function_types_interp (snd v)) :
- let sv := (s, DTfloat) in
- vartlookup env sv = Some oval ->
- v <> sv ->
- subvar sv (vart_update env v (addvar v env grad)) oval = 0%R.
- Proof.
- intros.
- unfold subvar; simpl.
- rewrite lookup_update_neq; trivial.
- rewrite H.
- lra.
- Qed.
-
- Lemma subvar_addvar_scalar_eq (env : df_env) (s : SubVar) (oval grad : float) :
- let v := (s, DTfloat) in
- vartlookup env v = Some oval ->
- subvar v (vart_update env v (addvar v env grad)) oval = grad.
- Proof.
- intros.
- unfold subvar; simpl.
- rewrite lookup_update.
- unfold addvar; simpl.
- rewrite H.
- lra.
- Qed.
-
- Lemma split_subvar (env1 env2: df_env) (oval val1 : float) (s : SubVar) :
- let v := (s, DTfloat) in
- vartlookup env1 v = Some val1 ->
- subvar v env2 oval = (subvar v env2 val1 + subvar v env1 oval)%R.
- Proof.
- intros.
- unfold subvar; simpl.
- rewrite H.
- case_eq (vartlookup env2 v); intros.
- lra.
- lra.
- Qed.
-
- Lemma vsum_nil : vsum vnil = 0%R.
- Proof.
- reflexivity.
- Qed.
-
- Lemma vsum_mult {n} (v : Vector float n) (c : float) :
- (c * vsum v)%R = vsum (fun j => (c * v j)%R).
- Proof.
- unfold vsum, vector_fold_right1, Datatypes.id; simpl.
- induction n; [ | destruct n].
- - repeat rewrite vector_fold_right1_dep_0; lra.
- - repeat rewrite vector_fold_right1_dep_1; lra.
- - repeat rewrite vector_fold_right1_dep_SSn.
- rewrite Rmult_plus_distr_l.
- specialize (IHn (vdrop_last v)); simpl in IHn.
- rewrite IHn.
- f_equal.
- apply vector_fold_right1_dep_ext.
- intros [i pf]; trivial.
- Qed.
-
- Lemma vsum_plus {m:nat} (v1 v2:Vector R m) :
- (vsum v1 + vsum v2)%R = vsum (fun i => (v1 i + v2 i)%R).
- Proof.
- unfold vsum, vector_fold_right1, Datatypes.id; simpl.
- induction m; [ | destruct m].
- - repeat rewrite vector_fold_right1_dep_0; lra.
- - repeat rewrite vector_fold_right1_dep_1; lra.
- - repeat rewrite vector_fold_right1_dep_SSn.
- specialize (IHm (vdrop_last v1) (vdrop_last v2)); simpl in IHm.
- rewrite (Rplus_comm (vlast v2)).
- rewrite (Rplus_assoc (vlast v1)).
- rewrite <- (Rplus_assoc _ _ (vlast v2)).
- rewrite IHm.
- rewrite (Rplus_comm _ (vlast v2)).
- rewrite <- Rplus_assoc.
- f_equal.
- apply vector_fold_right1_dep_ext.
- intros [i pf]; trivial.
- Qed.
-
- Lemma vmap_mult {n} (f: float -> float) (v : Vector float n) (c : float) :
- forall i : {n' : nat | n' < n},
- (c * (vmap f v) i)%R = (vmap (fun x => (c * f x)%R) v) i.
- Proof.
- intros.
- rewrite vmap_nth.
- now rewrite vmap_nth.
- Qed.
-
- Lemma vsum_ext {n} (v v':Vector float n) : vec_eq v v' -> vsum v = vsum v'.
- Proof.
- apply vector_fold_right1_ext.
- Qed.
-
- Lemma msum_ext {m n} (mat mat':Matrix float m n) :
- (forall i j, mat i j = mat' i j) -> msum mat = msum mat'.
- Proof.
- intros.
- apply vsum_ext; intros ?.
- repeat rewrite vmap_nth.
- apply vsum_ext; intros ?; auto.
- Qed.
-
- Lemma msum_mult {m n} (mat : Matrix float m n) (c : float) :
- (c * msum mat)%R = msum (fun i j => (c * mat i j)%R).
- Proof.
- unfold msum.
- rewrite vsum_mult.
- apply vsum_ext; intros i.
- repeat rewrite vmap_nth.
- now rewrite vsum_mult.
- Qed.
-
- Lemma msum_mmap_mult {m n} (mat : Matrix float m n) (c : float) :
- (c * msum mat)%R = msum (mmap (fun x => c * x)%R mat).
- Proof.
- rewrite msum_mult.
- apply msum_ext; intros i j.
- now rewrite mmap_nth.
- Qed.
-
- Lemma msum_mmap_div_denom {m n} (mat : Matrix float m n) (c : float) :
- msum (mmap (fun u : R => (u / c)%R) mat) = (msum mat / c)%R.
- Proof.
- transitivity (msum (mmap (fun u : R => (/ c * u)%R) mat)).
- - apply msum_ext; intros i j.
- repeat rewrite mmap_nth.
- lra.
- - rewrite <- msum_mmap_mult.
- lra.
- Qed.
-
- Lemma vsum0 n : vsum (fun _ : {n' : nat | (n' < n)%nat} => 0%R) = 0%R.
- Proof.
- generalize (vsum_mult (fun _ : {n' : nat | (n' < n)%nat} => 0%R) 0%R); intros HH.
- rewrite Rmult_0_l in HH.
- symmetry.
- simpl in *.
- erewrite vsum_ext; [eassumption | ].
- intro; simpl; lra.
- Qed.
-
- Lemma vsum_unitvector {n} (v:Vector R n) i :
- vsum (fun j => (v j * UnitVector n i j)%R) = v i.
- Proof.
- unfold vsum, vector_fold_right1, Datatypes.id, UnitVector; simpl.
- revert n v i.
- destruct i.
- induction n; [ | destruct n].
- - lia.
- - repeat rewrite vector_fold_right1_dep_1.
- destruct x; [ | lia]; simpl.
- field_simplify.
- now erewrite index_pf_irrel.
- - repeat rewrite vector_fold_right1_dep_SSn.
- unfold vlast, vdrop_last; simpl.
- destruct (equiv_dec (S n) x).
- + ring_simplify.
- simpl.
- destruct e.
- match goal with
- | [|- (_ + ?x)%R = _ ] => replace x with 0%R
- end.
- * ring_simplify.
- now erewrite index_pf_irrel.
- * rewrite <- (vsum0 (S n)) at 1.
- unfold vsum, vector_fold_right1, Fzero, Datatypes.id; simpl.
- apply (@vector_fold_right1_dep_ext (fun _ => R)).
- intros [??].
- destruct (equiv_dec x (S n)).
- -- destruct e.
- lia.
- -- lra.
- + ring_simplify.
- unfold equiv, complement in c.
- assert (pf:x < S n) by lia.
- specialize (IHn (vdrop_last v) pf).
- simpl in IHn.
- erewrite index_pf_irrel.
- rewrite <- IHn.
- apply (@vector_fold_right1_dep_ext (fun _ => R)).
- now intros [??].
- Qed.
-
- Lemma msum_unitmatrix {m n} (v:Matrix R m n) i j :
- msum (fun k l => (v k l * UnitMatrix m n i j k l)%R) = v i j.
- Proof.
- unfold msum.
- unfold UnitMatrix.
- rewrite (vsum_ext _ (
- (fun (k : {n' : nat | n' < m}) => @vsum floatish_R _
- (fun (l : {m' : nat | m' < n}) =>
- (v k l *
- (if equiv_dec (` k) (` i) then if equiv_dec (` l) (` j) then 1%R else 0%R else 0%R))%R))
-
- ))
- by (intros ?; now rewrite vmap_nth).
- rewrite (vsum_ext _ (
- (fun (k : {n' : nat | n' < m}) => (if equiv_dec (` k) (` i) then
- @vsum floatish_R _
- (fun (l : {m' : nat | m' < n}%nat) =>
- ((v k) l *
- if equiv_dec (` l) (` j) then 1%R else 0%R))%R else 0%R))
-
- )).
- - rewrite (vsum_ext _ (
- (fun (k : {n' : nat | n' < m}) => (if equiv_dec (` k) (` i) then
- v k j else 0%R))
-
- )).
- + rewrite (vsum_ext _ (
- (fun (k : {n' : nat | n' < m}) => ((transpose v) j k * @UnitVector floatish_R m i k)%R)
-
- )).
- * now rewrite vsum_unitvector.
- * unfold UnitVector; intros ?; simpl.
- dest_eqdec; unfold transpose; simpl
- ; lra.
- + intros ?.
- dest_eqdec; trivial.
- apply vsum_unitvector.
- - intros ?.
- dest_eqdec; trivial.
- rewrite <- (vsum0 n) at 1.
- apply vsum_ext.
- intros ?; lra.
- Qed.
-
- Ltac vectoro_assert_forall_in H i
- := match type of H with vectoro_to_ovector ?x = Some ?y =>
- assert (forall i, x i = Some (y i)) end.
-
- Lemma vartlookup_list_env_iter {A}
- (s: SubVar)
- (f : A -> df_env -> option df_env)
- (l : list A) (env fenv: df_env):
- list_env_iter f (Some env) l = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- (forall (env fenv: df_env) (i:A),
- f i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- vartlookup fenv (s, DTfloat) <> None.
- Proof.
- intros.
- revert H0 H.
- generalize env.
- induction l.
- - simpl; intros.
- now invcs H.
- - simpl; intros.
- generalize (list_env_iter_none f l); intros.
- assert (f a env0 <> None); [congruence | ].
- case_eq (f a env0); [|congruence].
- intros.
- apply (IHl d).
- + specialize (H1 env0 d a).
- now apply H1.
- + now rewrite H4 in H.
- Qed.
-
- (*
- Theorem df_eval_deriv_same {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s:SubVar) :
- let v := (s, DTfloat) in
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- df_eval_deriv σ df v = df_eval σ (df_deriv df v).
- Proof.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl in *;trivial; intros
- ; try (destruct H; rewrite IHdf1; trivial; rewrite IHdf2; trivial)
- ; try (rewrite IHdf; trivial; do 2 match_option).
- - Case "DVector"%string.
- f_equal.
- apply functional_extensionality; intros.
- rewrite vforall_forall in H0.
- specialize (H x0); simpl in H.
- specialize (H0 x0).
- apply H; trivial.
- - Case "DMatrix"%string.
- f_equal.
- apply functional_extensionality; intros.
- apply functional_extensionality; intros.
- rewrite vforall_forall in H0.
- specialize (H x0); simpl in H.
- specialize (H0 x0).
- rewrite vforall_forall in H0.
- apply H; trivial.
- - case "Times"%string.
- intros; do 4 match_option.
- - Case "Divide"%string.
- intros; do 4 match_option.
- - Case "Sign"%string.
- match_option.
- generalize (eval_deriv_fully_closed_not_none σ df (s, DTfloat)); tauto.
- - Case "PSign"%string.
- match_option.
- generalize (eval_deriv_fully_closed_not_none σ df (s, DTfloat)); tauto.
- - Case "Max"%string.
- assert (df_eval σ df1 <> None) by (apply eval_fully_closed_not_none; trivial).
- assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- match_option; [|congruence].
- assert (df_eval σ (df_deriv df1 (s, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- assert (df_eval σ (df_deriv df2 (s, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- destruct (df_eval σ (df_deriv df1 (s, DTfloat))); [|congruence].
- destruct (df_eval σ (df_deriv df2 (s, DTfloat))); [|congruence].
- unfold pos_sign; simpl.
- case_eq (Rle_dec d d0); intros; f_equal.
- destruct (Rge_dec (d0 - d) 0); lra.
- destruct (Rge_dec (d0 - d) 0); lra.
- - Case "VectorDot"%string.
- do 4 match_option.
- rewrite vsum_plus.
- do 2 f_equal.
- f_equal.
- apply functional_extensionality; intros.
- lra.
- - Case "MatrixVectorMult"%string.
- do 4 match_option.
- f_equal.
- apply functional_extensionality; intros.
- unfold matrix_vector_mult.
- rewrite vsum_plus; simpl.
- f_equal.
- apply functional_extensionality; intros.
- lra.
- - Case "MatrixMult"%string.
- do 4 match_option.
- f_equal.
- apply functional_extensionality; intros.
- apply functional_extensionality; intros.
- unfold matrix_mult.
- rewrite vsum_plus; simpl; f_equal.
- apply functional_extensionality; intros.
- lra.
- - Case "VectorScalMult"%string.
- do 4 match_option.
- f_equal.
- apply functional_extensionality; intros.
- lra.
- - Case "MatrixScalMult"%string.
- do 4 match_option.
- f_equal.
- apply functional_extensionality; intros.
- apply functional_extensionality; intros.
- lra.
- - Case "VectorApply"%string.
- destruct H.
- assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- rewrite IHdf2; trivial.
- match_option.
- + f_equal.
- apply functional_extensionality; intros.
- assert ( df_eval_deriv [mk_env_entry (v, DTfloat) (d x)] df1 (v, DTfloat) =
- df_eval σ (df_subst (df_deriv df1 (v, DTfloat)) (v, DTfloat)
- (VectorElem () df2 x))).
- XXX
- now rewrite H2.
- + assert ( df_eval σ (df_deriv df2 (s, DTfloat)) <> None ); [|tauto].
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- - Case "MatrixApply"%string.
- destruct H.
- assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- rewrite IHdf2; trivial.
- match_option.
- + f_equal.
- apply functional_extensionality; intros.
- apply functional_extensionality; intros.
- assert ( df_eval_deriv [mk_env_entry (v, DTfloat) (d x x0)] df1 (v, DTfloat) =
- df_eval σ (df_subst (df_deriv df1 (v, DTfloat)) (v, DTfloat)
- (MatrixElem () df2 x x0))).
- XXX
- now rewrite H2.
- + assert ( df_eval σ (df_deriv df2 (s, DTfloat)) <> None ); [|tauto].
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- - Case "VLossfun"%string.
- destruct H.
- assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- rewrite IHdf2; trivial.
- match_option.
- do 2 match_option.
- + do 2 f_equal.
- apply functional_extensionality; intros.
- XXX
- + generalize (vectoro_to_ovector_exists_None eqq2).
- intros; destruct H2.
- assert (df_eval
- σ
- (df_subst (df_subst (df_deriv df1 (v1, DTfloat)) (v1, DTfloat) (VectorElem () df2 x))
- (v2, DTfloat) (Number () (r x))) <> None); [|tauto].
- apply eval_fully_closed_not_none.
- apply fully_closed_subst; simpl; [|trivial].
- apply fully_closed_subst; simpl.
- * apply fully_closed_deriv.
- now apply fully_closed_over_pair.
- * now apply fully_closed_over_cons.
- + generalize (vectoro_to_ovector_exists_None eqq1).
- intros; destruct H2.
- match_option_in e.
- assert (df_eval_deriv [mk_env_entry (v1, DTfloat) (d x); mk_env_entry (v2, DTfloat) (r x)]
- df1 (v1, DTfloat) <> None); [|tauto].
- apply eval_deriv_fully_closed_not_none; trivial.
- - Case "MLossfun"%string.
- destruct H.
- assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- rewrite IHdf2; trivial.
- match_option.
- do 2 match_option.
- + do 2 f_equal.
- XXX
- + generalize (vectoro_to_ovector_exists_None eqq2); intros; destruct H2.
- generalize (vectoro_to_ovector_exists_None e); intros; destruct H2.
- match_option_in e0.
- match_option_in eqq3.
- assert (df_eval σ
- (df_subst
- (df_subst (df_deriv df1 (v1, DTfloat)) (v1, DTfloat) (MatrixElem () df2 x x0))
- (v2, DTfloat) (Number () (r x x0))) <> None); [|tauto].
- apply eval_fully_closed_not_none.
- apply fully_closed_subst; simpl; [|trivial].
- apply fully_closed_subst; simpl.
- * apply fully_closed_deriv.
- now apply fully_closed_over_pair.
- * now apply fully_closed_over_cons.
- + generalize (vectoro_to_ovector_exists_None eqq1).
- intros; destruct H2.
- generalize (vectoro_to_ovector_exists_None e).
- intros; destruct H2.
- match_option_in e0.
- assert (df_eval_deriv [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)]
- df1 (v1, DTfloat) <> None); [|tauto].
- apply eval_deriv_fully_closed_not_none; trivial.
- *)
-
- Theorem df_eval_deriv_scalar_same (σ:df_env) (df:DefinedFunction UnitAnn DTfloat) (s:SubVar) :
- let v := (s, DTfloat) in
- let vl := map (fun ve => projT1 ve) σ in
- is_scalar_function df ->
- fully_closed_over df vl ->
- df_eval_deriv σ df v = df_eval σ (df_deriv df v).
- Proof.
- simpl.
- intros is_scalar.
- generalize is_scalar.
- pattern df.
- revert df is_scalar.
- DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case
- ; simpl; trivial; intros
- ; try (destruct H1; destruct is_scalar;
- specialize (H H3 H1); specialize (H0 H4 H2); now rewrite H, H0)
- ; try (rewrite H; trivial; do 2 match_option).
- - Case "Times"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- rewrite H, H0.
- do 4 match_option.
- - Case "Divide"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- rewrite H, H0.
- do 4 match_option.
- - Case "Sign"%string.
- match_option.
- assert ( df_eval_deriv σ e (s, DTfloat) <> None); [|tauto].
- apply eval_deriv_fully_closed_not_none; trivial.
- - Case "PSign"%string.
- match_option.
- assert ( df_eval_deriv σ e (s, DTfloat) <> None); [|tauto].
- apply eval_deriv_fully_closed_not_none; trivial.
- - Case "Max"%string.
- destruct is_scalar; destruct H1.
- cut_to H; trivial.
- cut_to H0; trivial.
- rewrite H, H0.
- assert (df_eval σ l <> None) by (apply eval_fully_closed_not_none; trivial).
- assert (df_eval σ r <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- match_option; [|congruence].
- assert (df_eval σ (df_deriv l (s, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- assert (df_eval σ (df_deriv r (s, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- destruct (df_eval σ (df_deriv l (s, DTfloat))); [|congruence].
- destruct (df_eval σ (df_deriv r (s, DTfloat))); [|congruence].
- unfold pos_sign; simpl.
- case_eq (Rle_dec d d0); intros; f_equal.
- destruct (Rge_dec (d0 - d) 0); lra.
- destruct (Rge_dec (d0 - d) 0); lra.
- Qed.
-
- Lemma vartlookup_list_env_iter2 {A}
- (s: SubVar)
- {f : A -> df_env -> option df_env}
- {l : list A} {env fenv: df_env}:
- list_env_iter f (Some env) l = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- (forall (env fenv: df_env) (i:A),
- f i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- vartlookup fenv (s, DTfloat) <> None.
- Proof.
- apply (vartlookup_list_env_iter s f l env fenv).
- Qed.
-
- Lemma scalarMult_list_env_iter
- (s: SubVar) (c val1 val2:float) (A :Type)
- (f g : A -> df_env -> option df_env)
- (l : list A) (env1 env2 fenv1 fenv2: df_env):
- list_env_iter f (Some env1) l = Some fenv1 ->
- list_env_iter g (Some env2) l = Some fenv2 ->
- vartlookup env1 (s, DTfloat) = Some val1 ->
- vartlookup env2 (s, DTfloat) = Some val2 ->
- (forall (i:A) (env1 env2 fenv1 fenv2: df_env) (v1 v2: float),
- vartlookup env1 (s, DTfloat) = Some v1 ->
- vartlookup env2 (s, DTfloat) = Some v2 ->
- f i env1 = Some fenv1 -> g i env2 = Some fenv2 ->
- subvar (s, DTfloat) fenv1 v1 = (c * subvar (s, DTfloat) fenv2 v2)%R) ->
- (forall (env fenv: df_env) (i:A),
- f i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- (forall (env fenv: df_env) (i:A),
- g i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- subvar (s, DTfloat) fenv1 val1 = (c * subvar (s, DTfloat) fenv2 val2)%R.
- Proof.
- intros.
- generalize (vartlookup_list_env_iter s f l env1 fenv1); intros.
- assert (vartlookup env1 (s, DTfloat) <> None).
- rewrite H1; discriminate.
- assert (vartlookup env2 (s, DTfloat) <> None).
- rewrite H2; discriminate.
- specialize (H6 H H7 H4).
- generalize (vartlookup_list_env_iter s g l env2 fenv2); intros.
- specialize (H9 H0 H8 H5).
- revert H1 H2 H H0.
- generalize env1 env2 val1 val2.
- induction l.
- - intros.
- unfold subvar; simpl.
- unfold list_env_iter in H; simpl in H.
- unfold list_env_iter in H0; simpl in H0.
- invcs H; invcs H0.
- rewrite H1; rewrite H2; lra.
- - simpl; intros.
- generalize (list_env_iter_none f l); intros.
- assert (f a env0 <> None); [congruence | ].
- case_eq (f a env0); [intros|congruence].
- generalize (list_env_iter_none g l); intros.
- assert (g a env3 <> None); [congruence | ].
- case_eq (g a env3); [intros|congruence].
- assert (vartlookup d (s, DTfloat) <> None).
- apply (H4 env0 d a); trivial; congruence.
- assert (vartlookup d0 (s, DTfloat) <> None).
- apply (H5 env3 d0 a); trivial; congruence.
- case_eq (vartlookup d (s, DTfloat)); [intros | tauto].
- case_eq (vartlookup d0 (s, DTfloat)); [intros | tauto].
- specialize (IHl d d0 d1 d2).
- specialize (H3 a env0 env3 d d0 val0 val3).
- rewrite (split_subvar d fenv1 val0 d1); trivial.
- rewrite (split_subvar d0 fenv2 val3 d2); trivial.
- specialize (H3 H1 H2 H12 H15).
- rewrite H12 in H.
- rewrite H15 in H0.
- specialize (IHl H18 H19 H H0).
- lra.
- Qed.
-
- Lemma list_env_iter_subvar_env2
- (s: SubVar) (val1 val2:float) (A :Type)
- (f g : A -> df_env -> option df_env)
- (l : list A) (env1 env2 fenv1 fenv2: df_env):
- list_env_iter f (Some env1) l = Some fenv1 ->
- list_env_iter g (Some env2) l = Some fenv2 ->
- vartlookup env1 (s, DTfloat) = Some val1 ->
- vartlookup env2 (s, DTfloat) = Some val2 ->
- (forall (i:A) (env1 env2 fenv1 fenv2: df_env) (v1 v2: float),
- vartlookup env1 (s, DTfloat) = Some v1 ->
- vartlookup env2 (s, DTfloat) = Some v2 ->
- f i env1 = Some fenv1 -> g i env2 = Some fenv2 ->
- subvar (s, DTfloat) fenv1 v1 = subvar (s, DTfloat) fenv2 v2) ->
- (forall (env fenv: df_env) (i:A),
- f i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- (forall (env fenv: df_env) (i:A),
- g i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- subvar (s, DTfloat) fenv1 val1 = subvar (s, DTfloat) fenv2 val2.
- Proof.
- intros.
- generalize (scalarMult_list_env_iter s 1%R val1 val2 A f g l env1 env2 fenv1 fenv2).
- intros.
- specialize (H6 H H0 H1 H2).
- cut_to H6.
- now replace (1 * subvar (s, DTfloat) fenv2 val2)%R with (subvar (s, DTfloat) fenv2 val2) in H6 by lra.
- intros.
- replace (1 * subvar (s, DTfloat) fenv3 v2)%R with (subvar (s, DTfloat) fenv3 v2) by lra.
- apply (H3 i env0 env3 fenv0 fenv3); trivial.
- apply H4.
- apply H5.
- Qed.
-
- Lemma scalarMult_backprop_grad_scalar {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (s: SubVar) (grad_env1 grad_env2:df_env) (grad : definition_function_types_interp T) (c:float) :
- let v := (s, DTfloat) in
- vartlookup grad_env1 v <> None -> vartlookup grad_env2 v <> None ->
- df_eval_backprop_deriv σ df grad_env1 (scalarMult T c grad) <> None ->
- df_eval_backprop_deriv σ df grad_env2 grad <> None ->
- df_eval_backprop_delta σ df v grad_env1 (scalarMult T c grad) =
- lift (fun e => scalarMult (snd v) c e) (df_eval_backprop_delta σ df v grad_env2 grad).
- Proof.
- revert grad_env1 grad_env2.
- unfold df_eval_backprop_delta.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case
- ; simpl; intros grad_env1 grad_env2 neq1 neq2; intros.
- - Case "Number"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [|tauto].
- intros; simpl; f_equal.
- unfold subvar; simpl.
- match_destr; match_destr.
- inversion H1; subst.
- inversion H2; subst.
- lra.
- - Case "Constant"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- intros; simpl; f_equal.
- unfold subvar; simpl.
- match_destr; match_destr.
- inversion H1; subst.
- inversion H2; subst.
- lra.
- - Case "DVector"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- intros; simpl.
- unfold lift.
- match_option; [|tauto].
- case_eq (two_vector_env_iter_alt
- (fun (x0 : DefinedFunction Ann DTfloat) (g : R) (genv : df_env) =>
- df_eval_backprop_deriv σ x0 genv g) grad_env2 x grad); [|tauto].
- unfold two_vector_env_iter_alt in *.
- intros; f_equal.
- apply (scalarMult_list_env_iter
- s c d0 d {n' : nat | n' < n}
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (c * grad i)%R)
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad i))
- (bounded_seq0 n) grad_env1 grad_env2); trivial.
- + intros.
- specialize (H i (grad i) env1 env2).
- assert (vartlookup env1 (s, DTfloat) <> None); [congruence|].
- assert (vartlookup env2 (s, DTfloat) <> None); [congruence|].
- specialize (H H9 H10).
- assert (df_eval_backprop_deriv σ (x i) env1 (c * grad i)%R <> None); [congruence|].
- assert (df_eval_backprop_deriv σ (x i) env2 (grad i) <> None); [congruence|].
- specialize (H H11 H12).
- unfold lift in H; simpl in H.
- rewrite H5, H6, H7, H8 in H.
- now inversion H.
- + intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H5).
- + intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H5).
- - Case "DMatrix"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- intros; simpl.
- unfold lift.
- match_option; [|tauto].
- case_eq (two_matrix_env_iter_alt
- (fun (x0 : DefinedFunction Ann DTfloat) (g : R) (genv : df_env) =>
- df_eval_backprop_deriv σ x0 genv g) grad_env2 x grad); [|tauto].
- intros; f_equal.
- unfold two_matrix_env_iter_alt in *.
- apply (scalarMult_list_env_iter
- s c d0 d {n' : nat | n' < n}
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R)
- (Some env) (bounded_seq0 m))
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j)) (Some env)
- (bounded_seq0 m))
- (bounded_seq0 n) grad_env1 grad_env2); trivial.
- + intros.
- apply (scalarMult_list_env_iter
- s c v1 v2 {m' : nat | m' < m}
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R)
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (bounded_seq0 m) env1 env2); trivial.
- * intros.
- specialize (H i i0 (grad i i0) env0 env3).
- assert (vartlookup env0 (s, DTfloat) <> None); [congruence|].
- assert (vartlookup env3 (s, DTfloat) <> None); [congruence|].
- specialize (H H13 H14).
- assert (df_eval_backprop_deriv σ (x i i0) env0 (c * grad i i0)%R <> None); [congruence|].
- assert (df_eval_backprop_deriv σ (x i i0) env3 (grad i i0) <> None); [congruence|].
- specialize (H H15 H16).
- unfold lift in H; simpl in H.
- rewrite H9, H10, H11, H12 in H.
- now inversion H.
- * intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H9).
- * intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H9).
- + intros.
- apply (vartlookup_list_env_iter
- s
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R)
- (bounded_seq0 m) env fenv); trivial; intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H7).
- + intros.
- apply (vartlookup_list_env_iter
- s
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (bounded_seq0 m) env fenv); trivial; intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H7).
- - Case "Var"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- intros.
- destruct (vart_dec v (s, DTfloat)).
- + subst; simpl.
- rewrite H2; simpl.
- rewrite subvar_addvar_scalar_eq; trivial.
- rewrite H1; simpl.
- now rewrite subvar_addvar_scalar_eq.
- + case_eq (vartlookup grad_env1 v); intros; simpl.
- * case_eq (vartlookup grad_env2 v); intros; simpl; f_equal.
- -- rewrite subvar_addvar_scalar_neq; trivial.
- rewrite subvar_addvar_scalar_neq; trivial.
- lra.
- -- rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl.
- rewrite H1.
- lra.
- * case_eq (vartlookup grad_env2 v); intros; simpl; f_equal.
- -- rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl.
- rewrite H2.
- lra.
- -- unfold subvar; simpl.
- rewrite H2; rewrite H1.
- lra.
- - Case "Plus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad)%R); intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 grad d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (c * grad)%R); [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "Minus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad)%R); intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 (-grad)%R d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (-(c * grad))%R); [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 (-grad)%R); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- replace (c * -grad)%R with (-(c*grad))%R in IHdf2 by lra.
- rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "Times"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (d * grad)%R grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (d * (c * grad))%R); intros.
- rewrite H1, H2, H5 in H; simpl in H.
- rewrite H1, H2 in H0; simpl in H0.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 (d * grad)%R); intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (d0 * grad)%R d3 d4).
- replace (c * (d * grad))%R with (d * (c * grad))%R in IHdf1 by lra.
- rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d3 (d0 * (c * grad))%R); [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (d0 * grad)%R); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (c * (d0 * grad))%R with (d0 * (c*grad))%R in IHdf2 by lra.
- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) = Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "Divide"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (grad / d)%R grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad / d)%R); intros.
- rewrite H1 in H; rewrite H2 in H; simpl in H.
- rewrite H1 in H0; rewrite H2 in H0; simpl in H0.
- rewrite H5 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 (grad / d)%R); intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (- d0 / (d*d) * grad)%R d3 d4).
- replace (c * (grad / d))%R with (c * grad / d )%R in IHdf1 by lra.
- rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d3 (- d0 /(d * d) * (c * grad))%R); [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (- d0 / (d * d) * grad)%R); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (c * (-d0 / (d * d) * grad))%R with (- d0/(d * d) * (c*grad))%R in IHdf2 by lra.
- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) = Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "Square"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df); [ | tauto]; intros.
- specialize (IHdf (2 * d1 * grad)%R grad_env1 grad_env2); simpl in *.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- replace (2 * d1 * (c * grad))%R with (c * (2 * d1 * grad))%R by lra.
- rewrite H3 in H; rewrite H3 in H0; simpl in *.
- replace (2 * d1 * (c * grad))%R with (c * (2 * d1 * grad))%R in H by lra.
- now apply IHdf.
- - Case "Exp"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df); [ | tauto]; intros.
- specialize (IHdf (grad * exp d1)%R grad_env1 grad_env2); simpl in *.
- replace (c * grad * exp d1)%R with (c * (grad * exp d1))%R by lra.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- rewrite H3 in H; rewrite H3 in H0.
- replace (c * grad * exp d1)%R with (c * (grad * exp d1))%R in H by lra.
- now apply IHdf.
- - Case "Log"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df); [ | tauto]; intros.
- specialize (IHdf (grad / d1)%R grad_env1 grad_env2 ); simpl in *.
- replace (c * grad / d1)%R with (c * (grad / d1))%R by lra.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- rewrite H3 in H; rewrite H3 in H0.
- replace (c * grad / d1)%R with (c * (grad / d1))%R in H by lra.
- now apply IHdf.
- - Case "Abs"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df); [ | tauto]; intros.
- specialize (IHdf (grad * sign d1)%R grad_env1 grad_env2); simpl in *.
- replace (c * grad * (@sign floatish_R d1))%R
- with (c * (grad * (@sign floatish_R d1)))%R by lra.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- rewrite H3 in H; rewrite H3 in H0.
- replace (c * grad * (@sign floatish_R d1))%R
- with (c * (grad * (@sign floatish_R d1)))%R in H by lra.
- now apply IHdf.
- - Case "Sign"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (0)%R grad_env1 grad_env2); simpl in *.
- replace (0%R) with (c * 0)%R at 1 by lra.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- replace (0%R) with (c * 0)%R in H by lra.
- now apply IHdf.
- - Case "PSign"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (0)%R grad_env1 grad_env2); simpl in *.
- replace (0%R) with (c * 0)%R at 1 by lra.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- replace (0%R) with (c * 0)%R in H by lra.
- now apply IHdf.
- - Case "Max"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df1); [ | tauto]; intros.
- case_eq (df_eval σ df2); [ | tauto]; intros.
- rewrite H3 in H; rewrite H3 in H0.
- rewrite H4 in H; rewrite H4 in H0.
- case_eq (Rle_dec d1 d2); intros.
- + specialize (IHdf2 grad grad_env1 grad_env2); simpl in *.
- rewrite H1 in IHdf2; rewrite H2 in IHdf2.
- rewrite H5 in H; rewrite H5 in H0.
- now apply IHdf2.
- + specialize (IHdf1 grad grad_env1 grad_env2); simpl in *.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1.
- rewrite H5 in H; rewrite H5 in H0.
- now apply IHdf1.
- - Case "VectorDot"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df1); [ | tauto]; intros.
- case_eq (df_eval σ df2); [ | tauto]; intros.
- specialize (IHdf1 (vmap (fun rv => (rv * grad)%R) d2) grad_env1 grad_env2).
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1
- (vmap (fun rv : R => (rv * (c * grad))%R) d2)); intros.
- rewrite H3 in H; rewrite H4 in H; rewrite H5 in H; simpl in H.
- rewrite H3 in H0; rewrite H4 in H0; simpl in H0.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2
- (vmap (fun rv : R => (rv * grad)%R) d2)); intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (vmap (fun lv => (lv *grad)%R) d1) d3 d4).
- replace (fun i => (c * vmap (fun rv : R => rv * grad) d2 i)%R) with
- (vmap (fun rv : R => (rv * (c * grad))%R) d2) in IHdf1.
- rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d3
- (vmap (fun lv : R => (lv * (c * grad))%R) d1))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (vmap (fun lv : R => (lv * grad)%R) d1))
- ; [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d d5) by trivial.
- rewrite (split_subvar d4 d8 d0 d6) by trivial.
- replace
- (fun i : {n' : nat | n' < n} => (c * vmap (fun lv : R => lv * grad) d1 i)%R) with
- (vmap (fun lv : R => (lv * (c * grad))%R) d1) in IHdf2.
- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d) = Some (c * subvar (s, DTfloat) d4 d0)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_mult.
- assert ((fun lv => (lv * (c * grad))%R) = (fun x0 => (c * (x0 * grad))%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- now rewrite H13.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_mult.
- assert ((fun rv => (rv * (c * grad))%R) = (fun x0 => (c * (x0 * grad))%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- now rewrite H7.
- + now rewrite H6 in H0.
- + rewrite H3 in H; rewrite H4 in H.
- now rewrite H5 in H.
- - Case "VectorSum"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (ConstVector n grad) grad_env1 grad_env2).
- rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in IHdf.
- replace (ConstVector n (c * grad)%R) with
- (fun i => (c * ConstVector n grad i)%R).
- now apply IHdf.
- now unfold ConstVector.
- - Case "MatrixSum"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (ConstMatrix m n grad) grad_env1 grad_env2).
- rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in IHdf.
- replace (ConstMatrix m n (c * grad)%R) with
- (fun i j => (c * ConstMatrix m n grad i j)%R).
- now apply IHdf.
- now unfold ConstMatrix.
- - Case "VectorElem"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (fun k =>
- if equiv_dec (` k) (` i) then grad else 0%R) grad_env1 grad_env2).
- rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in *.
- replace (fun i0 => (c * (if equiv_dec (` i0) (` i) then grad else 0))%R) with
- (fun k : {n' : nat | n' < n} =>
- if equiv_dec (` k) (` i) then (c * grad)%R else 0%R) in IHdf.
- now rewrite IHdf.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)); lra.
- - Case "MatrixElem"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i)
- then if equiv_dec (` k2) (` j) then grad else 0%R else 0%R)
- grad_env1 grad_env2).
- rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in *.
- replace (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i)
- then if equiv_dec (` k2) (` j) then (c * grad)%R else 0%R else 0%R) with
- (fun (i0 : {n' : nat | n' < m}) (j0 : {m' : nat | m' < n}) =>
- (c *
- (if equiv_dec (` i0) (` i)
- then if equiv_dec (` j0) (` j) then grad else 0
- else 0))%R) in *.
- + now rewrite IHdf.
- + apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)); [|lra].
- destruct (equiv_dec (` x0) (` j)); lra.
- - Case "MatrixVectorMult"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (fun i j => (grad i * d j)%R) grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1
- (fun i j => (c * grad i * d j)%R)); intros.
- rewrite H1 in H; rewrite H2 in H; simpl in H.
- rewrite H1 in H0; rewrite H2 in H0; simpl in H0.
- rewrite H5 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2
- (fun i j => (grad i * d j)%R)); intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (matrix_vector_mult (fun i j => d0 j i) grad) d3 d4).
- replace
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (c * (grad i * d j))%R) with
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (c * grad i * d j)%R) in IHdf1.
- * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- unfold lift; match_case; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4
- (matrix_vector_mult (fun i j => d0 j i) grad))
- ; [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (fun i : {n' : nat | n' < n} =>
- (c *
- (@matrix_vector_mult floatish_R _ _
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d0 j i0) grad) i)%R) with
- (matrix_vector_mult
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => d0 j i)
- (fun i : {n' : nat | n' < m} => (c * grad i)%R)) in IHdf2.
- -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) =
- Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) =
- Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- -- unfold matrix_vector_mult.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vsum_mult; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "MatrixVectorAdd"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R))
- ; intros.
- + rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- * rewrite H4 in H0; simpl in H0.
- match_option_in H0; [|tauto].
- match_option_in H; [|tauto].
- rewrite H3 in IHdf1.
- unfold lift.
- f_equal.
- rewrite H4 in IHdf1.
- unfold lift in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite (split_subvar d1 d4 d0 d5) by trivial.
- rewrite (split_subvar d2 d3 d d6) by trivial.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R)
- by (apply IHdf1; trivial; discriminate).
- inversion H9; rewrite H11.
- assert (Some (subvar (s, DTfloat) d4 d5) = Some (c * subvar (s, DTfloat) d3 d6)%R).
- -- f_equal.
- apply (scalarMult_list_env_iter
- s c d5 d6 {m' : nat | m' < n}
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ df2 env
- (transpose
- (fun (i0 : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) =>
- (c * grad i0 j)%R) i))
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i))
- (bounded_seq0 n) d1 d2 d4 d3); trivial.
- ++ intros.
- assert (vartlookup env1 (s, DTfloat) <> None); [congruence|].
- assert (vartlookup env2 (s, DTfloat) <> None); [congruence|].
- assert (df_eval_backprop_deriv
- σ df2 env1
- (transpose
- (fun (i0 : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (c * grad i0 j)%R) i) <> None)
- ; [congruence|].
- assert (df_eval_backprop_deriv σ df2 env2 (transpose grad i) <> None)
- ;[congruence|].
- specialize (IHdf2 (transpose grad i) env1 env2).
- specialize (IHdf2 H15 H16).
- specialize (IHdf2 H17 H18).
- unfold lift in IHdf2; simpl in IHdf2.
- rewrite H10, H12, H14 in IHdf2.
- unfold transpose in IHdf2; unfold transpose in H13.
- rewrite H13 in IHdf2; now inversion IHdf2.
- ++ intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H10).
- ++ intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H10).
- -- inversion H10; rewrite H13; lra.
- * rewrite H4 in H0; tauto.
- + rewrite H3 in H; tauto.
- - Case "MatrixMult"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (matrix_mult grad (fun i j => d j i)) grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1
- (matrix_mult (fun i j => (c * grad i j)%R)
- (fun i j => d j i))); intros.
- rewrite H1 in H; rewrite H2 in H; simpl in H.
- rewrite H1 in H0; rewrite H2 in H0; simpl in H0.
- rewrite H5 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2
- (matrix_mult grad (fun i j => d j i))); intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (matrix_mult (fun i j => d0 j i) grad) d3 d4).
- replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < p}) =>
- (c *
- (@matrix_mult floatish_R m n p grad
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) =>
- d j0 i0)) i j)%R) with
- (matrix_mult (fun i j => (c * grad i j)%R)
- (fun i j => d j i)) in IHdf1.
- * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- unfold lift; match_case; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (matrix_mult (fun i j => d0 j i) grad))
- ; [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < n}) =>
- (c *
- (@matrix_mult floatish_R p m n
- (fun (i0 : {n' : nat | (n' < p)%nat})
- (j0 : {m' : nat | (m' < m)%nat}) =>
- d0 j0 i0) grad) i j)%R) with
- (matrix_mult (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < m}) => d0 j i)
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (c * grad i j)%R)) in IHdf2.
- -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) =
- Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) =
- Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- -- unfold matrix_mult.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vsum_mult; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- * unfold matrix_mult.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vsum_mult; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "VectorPlus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i => (c * grad i)%R)); intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 grad d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (fun i => (c * grad i)%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "VectorMinus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i => (c * grad i )%R))
- ; intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 (fun i => (- grad i)%R) d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (fun i => (- (c * grad i))%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 (fun i => (- grad i)%R)); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- replace (fun i => (c * - grad i)%R) with (fun i => (-( c * grad i))%R) in IHdf2.
- * rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "MatrixPlus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R))
- ; intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 grad d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (fun i j => (c * grad i j)%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "MatrixMinus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R))
- ; intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 (fun i j => (- grad i j)%R) d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (fun i j => (- (c * grad i j))%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 (fun i j => (- grad i j)%R)); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- replace (fun i j => (c * - grad i j)%R) with
- (fun i j => (-( c * grad i j))%R) in IHdf2.
- * rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "VectorScalMult"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (vsum (fun j => (d j * grad j)%R)) grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1
- (vsum (fun j => (d j * (c * grad j))%R)))
-
- ; intros.
- rewrite H1 in H; rewrite H2 in H; simpl in H.
- rewrite H1 in H0; rewrite H2 in H0; simpl in H0.
- rewrite H5 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2
- (vsum (fun j => (d j * grad j)%R)))
- ; intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (fun j => (grad j * d0)%R) d3 d4).
- replace
- (c *
- (@vsum floatish_R _
- (fun (j : {n' : nat | (n' < n)%nat}) =>
- d j * grad j)))%R with
- (vsum
- (fun (j : {n' : nat | n' < n}) =>
- (d j * (c * grad j))%R)) in IHdf1.
- * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d3
- (fun (j : {n' : nat | n' < n}) =>
- (d0 * (c * grad j))%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (fun j => (d0 * grad j)%R))
- ; [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (fun i => (c * (grad i * d0))%R) with
- (fun j => (d0 * (c * grad j))%R) in IHdf2.
- replace (fun j => (grad j * d0)%R) with (fun j => (d0 * grad j)%R) in IHdf2.
- -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) =
- Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) =
- Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- * rewrite vsum_mult; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "MatrixScalMult"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (msum (fun i j => (d i j * grad i j)%R)) grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1
- (msum (fun i j => (d i j * (c * grad i j))%R)))
-
- ; intros.
- rewrite H1, H2, H5 in H; simpl in H.
- rewrite H1, H2 in H0; simpl in H0.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2
- (msum (fun i j => (d i j * grad i j)%R)))
- ; intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (fun i j => (grad i j * d0)%R) d3 d4).
- replace
- (c *
- (@msum floatish_R _ _
- (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d i j * grad i j)))%R with
-
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d i j * (c * grad i j))%R)) in IHdf1.
- * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d3
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (c * grad i j * d0)%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (fun i j => (grad i j * d0)%R))
- ; [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (fun i j => (c * (grad i j * d0))%R) with
- (fun i j => (c * grad i j * d0)%R) in IHdf2.
- -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) =
- Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) =
- Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- * unfold msum.
- rewrite vsum_mult; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite vsum_mult.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "VectorApply"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- simpl in *.
- case_eq (df_eval σ df2); [ | tauto].
- intros.
- rewrite H3 in H; rewrite H3 in H0; simpl in *.
- match_option_in H0; [|tauto].
- specialize (IHdf1 v0 grad_env1 grad_env2).
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1.
- match_option_in H; [|tauto].
- vectoro_assert_forall_in eqq i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in eqq0 i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- assert (v1 = (fun i => (c * v0 i)%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H4 x).
- specialize (H5 x).
- rewrite vmap_nth in H4; simpl in H4.
- rewrite vmap_nth in H5; simpl in H5.
- match_option_in H4.
- match_option_in H5.
- inversion H4; inversion H5; subst.
- assert (Some d2 = Some d3).
- rewrite <- eqq1.
- rewrite <- eqq2; trivial.
- inversion H6; subst; lra.
- subst.
- apply IHdf1; trivial; discriminate.
- - Case "MatrixApply"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- simpl in *.
- case_eq (df_eval σ df2); [ | tauto].
- intros.
- rewrite H3 in H; rewrite H3 in H0; simpl in *.
- match_option_in H0; [|tauto].
- match_option_in H; [|tauto].
- unfold matrixo_to_omatrix in eqq.
- unfold matrixo_to_omatrix in eqq0.
- vectoro_assert_forall_in eqq i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in eqq0 i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- assert (m1 = (fun i j => (c * m0 i j)%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H4 x); specialize (H5 x); simpl in H4; simpl in H5.
- vectoro_assert_forall_in H4 j.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in H5 j.
- apply vectoro_to_ovector_forall_some_f; trivial.
- specialize (H6 x0); specialize (H7 x0).
- unfold mmap in H6; unfold mmap in H7.
- rewrite vmap_nth in H6; rewrite vmap_nth in H6; simpl in H6.
- rewrite vmap_nth in H7; rewrite vmap_nth in H7; simpl in H7.
- match_case_in H6; intros.
- rewrite H8 in H6; simpl in H6.
- match_case_in H7; intros.
- rewrite H9 in H7; simpl in H7.
- match_option_in H6.
- match_option_in H7.
- inversion H6; inversion H7.
- unfold matrix_zip in H8.
- unfold matrix_zip in H9.
- rewrite vmap_nth in H8.
- rewrite vmap_nth in H9.
- unfold vector_zip in H8.
- unfold vector_zip in H9.
- inversion H8; subst r r0.
- inversion H9; subst r1 r2.
- assert (Some d2 = Some d3).
- rewrite <- eqq1; rewrite <- eqq2; trivial.
- inversion H10; subst; lra.
- specialize (IHdf1 m0 grad_env1 grad_env2).
- rewrite H1, H2 in IHdf1; simpl in IHdf1.
- subst m1.
- apply IHdf1; trivial; discriminate.
- - Case "VLossfun"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- simpl in *.
- case_eq (df_eval σ df2); [ | tauto].
- intros.
- rewrite H3 in H; rewrite H3 in H0; simpl in *.
- match_option_in H0; [|tauto].
- match_option_in H; [|tauto].
- vectoro_assert_forall_in eqq i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in eqq0 i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- assert (v0 = (fun i => (c * v i)%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H4 x); specialize (H5 x).
- rewrite vmap_nth in H4; simpl in H4.
- rewrite vmap_nth in H5; simpl in H5.
- match_option_in H4.
- match_option_in H5.
- assert (Some d2 = Some d3).
- rewrite <- eqq1; rewrite <- eqq2; trivial.
- inversion H6; subst.
- inversion H4; inversion H5; lra.
- subst.
- specialize (IHdf1 v grad_env1 grad_env2).
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1.
- apply IHdf1; trivial; discriminate.
- - Case "MLossfun"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- simpl in *.
- case_eq (df_eval σ df2); [ | tauto].
- intros.
- rewrite H3 in H; rewrite H3 in H0; simpl in *.
- match_option_in H0; [|tauto].
- match_option_in H; [|tauto].
- unfold matrixo_to_omatrix in eqq.
- unfold matrixo_to_omatrix in eqq0.
- vectoro_assert_forall_in eqq i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in eqq0 i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- assert (m1 = (fun i j => (c * m0 i j)%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H4 x); specialize (H5 x); simpl in H4; simpl in H5.
- vectoro_assert_forall_in H4 j.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in H5 j.
- apply vectoro_to_ovector_forall_some_f; trivial.
- specialize (H6 x0); specialize (H7 x0).
- unfold mmap in H6; unfold mmap in H7.
- rewrite vmap_nth in H6; rewrite vmap_nth in H6; simpl in H6.
- rewrite vmap_nth in H7; rewrite vmap_nth in H7; simpl in H7.
- match_destr_in H6.
- match_option_in H6.
- match_option_in H7.
- assert (Some d2 = Some d3).
- rewrite <- eqq1; rewrite <- eqq2; trivial.
- inversion H8; subst.
- inversion H6; inversion H7.
- lra.
- rewrite H6.
- specialize (IHdf1 m0 grad_env1 grad_env2).
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1.
- subst.
- apply IHdf1; trivial; discriminate.
- Qed.
-
- Ltac simpl_closed_backprop :=
- match goal with
- | [|- context [
- match df_eval_backprop_deriv ?σ ?df1 ?grad_env1 ?grad with
- | Some _ => _
- | None => _
- end]] => case_eq (df_eval_backprop_deriv σ df1 grad_env1 grad)
- ; [let env := fresh "env" in let eqq := fresh "eqq" in intros env eqq |
- let eqq := fresh "eqq" in
- intros eqq;
- eelim backprop_deriv_fully_closed_not_none; [clear eqq | eapply eqq]; trivial
- ]
- end.
-
- Ltac simpler2 :=
- trivial;
- repeat
- match goal with
- | [ |- Some _ <> None ] => congruence
- | [ |- None <> Some _ ] => congruence
-
- | [H:vartlookup ?grad_env ?a <> None
- |- context [vartlookup ?grad_env ?a]] =>
- case_eq (vartlookup grad_env a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d
- | intros ?; eelim H; solve[auto]]
- | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2
- |- context [vartlookup ?grad_env2 ?a]] =>
- case_eq (vartlookup grad_env2 a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d
- | let eqq := fresh "eqq" in
- intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto]
- ]
-
- | [H:vartlookup ?grad_env ?a <> None
- |- context [vartlookup ?grad_env ?a]] =>
- case_eq (vartlookup grad_env a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d
- | intros ?; eelim H; solve[auto || congruence]]
- | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2
- |- context [vartlookup ?grad_env2 ?a]] =>
- case_eq (vartlookup grad_env2 a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d
- | let eqq := fresh "eqq" in
- intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto || congruence]
- ]
- | [H:vartlookup ?grad_env ?a <> None,
- H2:context [match vartlookup ?grad_env ?a with | _ => _ end] |- _] =>
- case_eq (vartlookup grad_env a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d; rewrite eqq in H2
- | intros ?; eelim H; solve[auto || congruence]]
- | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2,
- H2: context [match vartlookup ?grad_env2 ?a with _ => _ end] |- _] =>
- case_eq (vartlookup grad_env2 a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d; rewrite eqq in H2
- | let eqq := fresh "eqq" in
- intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto || congruence]]
-
- end.
-
- Lemma backprop_indep_env {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s:SubVar)
- (grad_env1 grad_env2:df_env) (grad : definition_function_types_interp T) :
- let v := (s, DTfloat) in
- vartlookup grad_env1 v <> None ->
- vartlookup grad_env2 v <> None ->
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- df_eval_backprop_delta σ df v grad_env1 grad =
- df_eval_backprop_delta σ df v grad_env2 grad.
- Proof.
- revert grad_env1 grad_env2.
- unfold df_eval_backprop_delta.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case
- ; simpl; intros grad_env1 grad_env2 neq1 neq2; intros.
- - Case "Number"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold snd in *.
- f_equal.
- unfold subvar; simpl.
- rewrite eqq, eqq0.
- lra.
- - Case "Constant"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold snd in *.
- f_equal.
- unfold subvar; simpl.
- rewrite eqq, eqq0.
- lra.
- - Case "DVector"%string.
- rewrite vforall_forall in H0.
- unfold two_vector_env_iter_alt.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift; simpl.
- match_option.
- + match_option.
- f_equal.
- apply (list_env_iter_subvar_env2
- s d d0 {n' : nat | n' < n}
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad i))
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad i))
- (bounded_seq0 n) grad_env1 grad_env2 d1 d2); trivial.
- * intros.
- specialize (H i (grad i) env1 env2).
- cut_to H; try congruence; eauto 3.
- rewrite H1, H2, H3, H4 in H.
- unfold lift in H.
- now inversion H.
- * intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H1); trivial.
- * intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H1); trivial.
- * assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad i)) (Some grad_env2)
- (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- + assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad i)) (Some grad_env1)
- (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- - Case "DMatrix"%string.
- rewrite vforall_forall in H0.
- unfold two_matrix_env_iter_alt.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift; simpl.
- match_option.
- + match_option.
- * f_equal.
- apply (list_env_iter_subvar_env2
- s d d0 {n' : nat | n' < n}
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (Some env) (bounded_seq0 m))
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (Some env) (bounded_seq0 m))
- (bounded_seq0 n) grad_env1 grad_env2 d1 d2); trivial.
- -- intros.
- apply (list_env_iter_subvar_env2
- s v1 v2 {m' : nat | m' < m}
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (bounded_seq0 m) env1 env2 fenv1 fenv2); trivial.
- ++ intros.
- specialize (H i i0 (grad i i0) env0 env3).
- cut_to H.
- rewrite H5, H6, H7, H8 in H.
- unfold lift in H.
- now inversion H.
- congruence.
- congruence.
- specialize (H0 i).
- rewrite vforall_forall in H0.
- apply (H0 i0).
- ++ intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5); trivial.
- ++ intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5); trivial.
- -- intros.
- apply (vartlookup_list_env_iter
- s
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (bounded_seq0 m) env fenv); trivial.
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3); trivial.
- -- intros.
- apply (vartlookup_list_env_iter
- s
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (bounded_seq0 m) env fenv); trivial.
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3); trivial.
- * assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (Some env) (bounded_seq0 m)) (Some grad_env2) (bounded_seq0 n)
- <> None).
- apply list_env_iter_total_fun; intros.
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- specialize (H0 a).
- rewrite vforall_forall in H0.
- apply (H0 a0).
- tauto.
- + assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (Some env) (bounded_seq0 m)) (Some grad_env1) (bounded_seq0 n)
- <> None).
- apply list_env_iter_total_fun; intros.
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- specialize (H0 a).
- rewrite vforall_forall in H0.
- apply (H0 a0).
- tauto.
- - Case "Var"%string.
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- unfold lift, subvar; simpl.
- destruct (v == (s,DTfloat)).
- + invcs e.
- unfold addvar; simpl.
- rewrite eqq, H0.
- rewrite lookup_update.
- rewrite lookup_update.
- f_equal; lra.
- + assert (v<> (s, DTfloat)) by congruence.
- case_eq (vartlookup grad_env1 v); intros.
- * rewrite lookup_update_neq; trivial.
- rewrite eqq.
- case_eq (vartlookup grad_env2 v); intros.
- -- rewrite lookup_update_neq; trivial.
- rewrite H0; f_equal; lra.
- -- rewrite H0; f_equal; lra.
- * rewrite eqq.
- case_eq (vartlookup grad_env2 v); intros.
- -- rewrite lookup_update_neq; trivial.
- rewrite H0; f_equal; lra.
- -- rewrite H0; f_equal; lra.
- - Case "Plus"%string.
- destruct H.
- unfold lift.
- repeat simpl_closed_backprop.
- simpler2.
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1,eqq3,eqq4 in IHdf1.
- specialize (IHdf2 grad env env1).
- cut_to IHdf2; simpler2.
- unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq2, eqq0 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar env env0 val val1); trivial.
- rewrite (split_subvar env1 env2 val0 val2); trivial.
- rewrite H2, H3; lra.
- - Case "Minus"%string.
- destruct H.
- unfold lift.
- repeat simpl_closed_backprop.
- simpler2.
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1,eqq3,eqq4 in IHdf1.
- specialize (IHdf2 (-grad)%R env env1).
- cut_to IHdf2; simpler2.
- unfold lift in IHdf1; invcs IHdf1.
- rewrite eqq0,eqq2 in IHdf2; unfold lift in IHdf2; invcs IHdf2.
- rewrite (split_subvar env env0 val val1); trivial.
- rewrite (split_subvar env1 env2 val0 val2); trivial.
- rewrite H2, H3; lra.
- - Case "Times"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- unfold lift.
- repeat simpl_closed_backprop.
- simpler2.
- f_equal.
- specialize (IHdf1 (d1 * grad)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq6,eqq2,eqq4 in IHdf1.
- specialize (IHdf2 (d0 * grad)%R env env1).
- cut_to IHdf2; simpler2.
- unfold lift in IHdf1; invcs IHdf1.
- rewrite eqq5, eqq3 in IHdf2; unfold lift in IHdf2; invcs IHdf2.
- rewrite (split_subvar env env0 d val0); trivial.
- rewrite (split_subvar env1 env2 val val1); trivial.
- rewrite H4, H5; lra.
- - Case "Divide"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- unfold lift.
- repeat simpl_closed_backprop.
- simpler2.
- f_equal.
- specialize (IHdf1 (grad/d1)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq6,eqq2,eqq4 in IHdf1.
- specialize (IHdf2 (-d0/(d1*d1) * grad)%R env env1).
- cut_to IHdf2; simpler2.
- unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq5, eqq3 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar env env0 d val0); trivial.
- rewrite (split_subvar env1 env2 val val1); trivial.
- rewrite H4, H5; lra.
- - Case "Square"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift.
- repeat simpl_closed_backprop.
- f_equal.
- specialize (IHdf (2 * d0 * grad)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1,eqq2,eqq3 in IHdf.
- now unfold lift in IHdf; inversion IHdf.
- - Case "Exp"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift.
- repeat simpl_closed_backprop.
- f_equal.
- specialize (IHdf (grad * exp d0)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1,eqq2, eqq3 in IHdf.
- now unfold lift in IHdf; inversion IHdf.
- - Case "Log"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env1 (grad/d0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- unfold lift.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env2 (grad/d0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf (grad / d0)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf.
- now rewrite eqq2, eqq3 in IHdf; unfold lift in IHdf; inversion IHdf.
- - Case "Abs"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env1 (grad * sign d0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- unfold lift.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env2 (grad * sign d0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf (grad * sign d0)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf.
- now rewrite eqq2, eqq3 in IHdf; unfold lift in IHdf; inversion IHdf.
- - Case "Sign"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env1 (0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- unfold lift.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env2 (0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf (0)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq0 in IHdf.
- now rewrite eqq1, eqq2 in IHdf; unfold lift in IHdf; inversion IHdf.
- - Case "PSign"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env1 (0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- unfold lift.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env2 (0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf (0)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq0 in IHdf.
- now rewrite eqq1, eqq2 in IHdf; unfold lift in IHdf; inversion IHdf.
- - Case "Max"%string.
- destruct H.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- specialize (IHdf2 grad grad_env1 grad_env2 neq1 neq2 H0).
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- rewrite eqq, eqq2 in IHdf1.
- rewrite eqq, eqq2 in IHdf2.
- destruct (Rle_dec d0 d1); trivial.
- - Case "VectorDot"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1 (vmap (fun rv : R => (rv * grad)%R) d1)
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H3 in IHdf1.
- assert (df_eval_backprop_deriv σ df1 grad_env1
- (vmap (fun rv : R => (rv * grad)%R) d1) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (vmap (fun rv : R => (rv * grad)%R) d1) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1).
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2).
- specialize (IHdf2 (vmap (fun lv : R => (lv * grad)%R) d0)
- d3 d4 H6 H7 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv σ df2 d3
- (vmap (fun lv : R => (lv * grad)%R) d0) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (vmap (fun lv : R => (lv * grad)%R) d0) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- unfold lift in IHdf1.
- rewrite eqq2, eqq3 in IHdf1; inversion IHdf1.
- rewrite eqq6, eqq7 in IHdf2; inversion IHdf2.
- rewrite (split_subvar d3 d7 d d5); trivial.
- rewrite (split_subvar d4 d8 d2 d6); trivial.
- now rewrite H11, H12.
- - Case "VectorSum"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (ConstVector n grad) grad_env1 grad_env2 neq1 neq2 H).
- now rewrite eqq, eqq0 in IHdf.
- - Case "MatrixSum"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (ConstMatrix m n grad) grad_env1 grad_env2 neq1 neq2 H).
- now rewrite eqq, eqq0 in IHdf.
- - Case "VectorElem"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i)
- then grad else 0%R)
- grad_env1 grad_env2 neq1 neq2 H).
- now rewrite eqq, eqq0 in IHdf.
- - Case "MatrixElem"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i) then
- if equiv_dec (` k2) (` j) then grad else 0%R else 0%R)
- grad_env1 grad_env2 neq1 neq2 H).
- now rewrite eqq, eqq0 in IHdf.
- - Case "MatrixVectorMult"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1 (fun (i : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (grad i * d1 j)%R)
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H3 in IHdf1.
- assert (df_eval_backprop_deriv σ df1 grad_env1
- (fun (i : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (grad i * d1 j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (fun (i : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (grad i * d1 j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1).
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2).
- specialize (IHdf2
- (matrix_vector_mult
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d0 j i) grad)
- d3 d4 H6 H7 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv
- σ df2 d3
- (matrix_vector_mult (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d0 j i)
- grad) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (matrix_vector_mult (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d0 j i)
- grad) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- unfold lift in IHdf1.
- unfold lift in IHdf2.
- rewrite eqq2, eqq3 in IHdf1; inversion IHdf1.
- rewrite eqq6, eqq7 in IHdf2; inversion IHdf2.
- rewrite (split_subvar d3 d7 d d5); trivial.
- rewrite (split_subvar d4 d8 d2 d6); trivial.
- now rewrite H11, H12.
- - Case "MatrixVectorAdd"%string.
- destruct H.
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1 grad
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H1 in IHdf1.
- assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); [intros | tauto].
- rewrite eqq0, H4 in IHdf1.
- unfold lift in IHdf1.
- inversion IHdf1.
- assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d1)
- (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d2)
- (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- unfold lift; f_equal.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1).
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- assert (vartlookup d3 (s, DTfloat) <> None).
- apply
- (vartlookup_list_env_iter
- s
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i))
- (bounded_seq0 n) d1); trivial.
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H10 (s, DTfloat) H11).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply
- (vartlookup_list_env_iter
- s
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i))
- (bounded_seq0 n) d2); trivial.
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H11 (s, DTfloat) H12).
- unfold subvar in IHdf1; simpl in IHdf1.
- match_option_in IHdf1; [|tauto].
- match_option_in IHdf1; [|tauto].
- rewrite (split_subvar d1 d3 d d5); trivial.
- rewrite (split_subvar d2 d4 d0 d6); trivial.
- rewrite H6.
- apply Rplus_eq_compat_r .
- apply (list_env_iter_subvar_env2
- s d5 d6 {m' : nat | m' < n}
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i))
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i))
- (bounded_seq0 n) d1 d2 d3 d4); trivial.
- intros.
- specialize (IHdf2 (transpose grad i) env1 env2).
- rewrite H12, H13 in IHdf2.
- cut_to IHdf2; trivial.
- rewrite H14, H15 in IHdf2.
- unfold lift in IHdf2.
- now inversion IHdf2.
- discriminate.
- discriminate.
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H12 (s, DTfloat) H13).
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H12 (s, DTfloat) H13).
- - Case "MatrixMult"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1
- (matrix_mult grad (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < p}) => d1 j i))
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H3 in IHdf1.
- assert (df_eval_backprop_deriv σ df1 grad_env1
- (matrix_mult grad (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < p}) => d1 j i)) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (matrix_mult grad (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < p}) => d1 j i)) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- rewrite eqq2, eqq3 in IHdf1.
- unfold lift in IHdf1; inversion IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1).
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2).
- specialize (IHdf2
- (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d0 j i) grad)
- d3 d4 H6 H8 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv
- σ df2 d3
- (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d0 j i) grad) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv
- σ df2 d4
- (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d0 j i) grad) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- unfold lift in IHdf1.
- unfold lift in IHdf2.
- rewrite eqq6, eqq7 in IHdf2; inversion IHdf2.
- rewrite (split_subvar d3 d7 d d5); trivial.
- rewrite (split_subvar d4 d8 d2 d6); trivial.
- now rewrite H7, H12.
- - Case "VectorPlus"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv σ df2 d0 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf1.
- specialize (IHdf2 grad d0 d2).
- assert (vartlookup d0 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1).
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2).
- specialize (IHdf2 H5 H6 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar d0 d3 d d5); trivial.
- rewrite (split_subvar d2 d4 d1 d6); trivial.
- lra.
- - Case "VectorMinus"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv
- σ df2 d0
- (fun i : {n' : nat | n' < n} => (- grad i)%R) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d2
- (fun i : {n' : nat | n' < n} => (- grad i)%R) <> None)
- by (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf1.
- specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad i)%R) d0 d2).
- assert (vartlookup d0 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1).
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2).
- specialize (IHdf2 H5 H6 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar d0 d3 d d5); trivial.
- rewrite (split_subvar d2 d4 d1 d6); trivial.
- lra.
- - Case "MatrixPlus"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv σ df2 d0 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf1.
- specialize (IHdf2 grad d0 d2).
- assert (vartlookup d0 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1).
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2).
- specialize (IHdf2 H5 H6 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar d0 d3 d d5); trivial.
- rewrite (split_subvar d2 d4 d1 d6); trivial.
- lra.
- - Case "MatrixMinus"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv
- σ df2 d0
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf1.
- specialize (IHdf2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R)
- d0 d2).
- assert (vartlookup d0 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1).
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2).
- specialize (IHdf2 H5 H6 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar d0 d3 d d5); trivial.
- rewrite (split_subvar d2 d4 d1 d6); trivial.
- lra.
- - Case "VectorScalMult"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R))
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H3 in IHdf1.
- assert (df_eval_backprop_deriv
- σ df1 grad_env1
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R)) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R))
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1).
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2).
- specialize (IHdf2 (fun j : {n' : nat | n' < n} => (d0 * grad j)%R)
- d3 d4 H6 H7 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv σ df2 d3
- (fun j : {n' : nat | n' < n} => (d0 * grad j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (fun j : {n' : nat | n' < n} => (d0 * grad j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- unfold lift in IHdf1.
- rewrite eqq2, eqq3 in IHdf1; inversion IHdf1.
- rewrite eqq6, eqq7 in IHdf2; inversion IHdf2.
- rewrite (split_subvar d3 d7 d d5); trivial.
- rewrite (split_subvar d4 d8 d2 d6); trivial.
- now rewrite H11, H12.
- - Case "MatrixScalMult"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1
- (msum
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R))
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H3 in IHdf1.
- assert (df_eval_backprop_deriv
- σ df1 grad_env1
- (msum
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R)) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (msum
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R)) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1).
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2).
- specialize (IHdf2 (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad i j * d0)%R)
- d3 d4 H6 H7 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv σ df2 d3
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad i j * d0)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad i j * d0)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- unfold lift in IHdf1.
- rewrite eqq2, eqq3 in IHdf1; inversion IHdf1.
- rewrite eqq6, eqq7 in IHdf2; inversion IHdf2.
- rewrite (split_subvar d3 d7 d d5); trivial.
- rewrite (split_subvar d4 d8 d2 d6); trivial.
- now rewrite H11, H12.
- - Case "VectorApply"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- match_option.
- specialize (IHdf1 v0 grad_env1 grad_env2 neq1 neq2 H0).
- now rewrite eqq, H2 in IHdf1.
- - Case "MatrixApply"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- match_option.
- specialize (IHdf1 m0 grad_env1 grad_env2 neq1 neq2 H0).
- now rewrite eqq, H2 in IHdf1.
- - Case "VLossfun"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- match_option.
- specialize (IHdf1 v grad_env1 grad_env2 neq1 neq2 H0).
- now rewrite eqq, H2 in IHdf1.
- - Case "MLossfun"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- match_option.
- specialize (IHdf1 m0 grad_env1 grad_env2 neq1 neq2 H0).
- now rewrite eqq, H2 in IHdf1.
- Qed.
-
- Lemma backprop_exchange_order {T} (σ:df_env) (df1 df2 :DefinedFunction UnitAnn T) (s: SubVar)
- (env:df_env) (grad1 grad2 : definition_function_types_interp T) :
- let v := (s, DTfloat) in
- vartlookup env v <> None ->
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df1 vl ->
- fully_closed_over df2 vl ->
- match
- df_eval_backprop_deriv σ df1 env grad1, df_eval_backprop_deriv σ df2 env grad2, vartlookup env v with
- | Some env1, Some env2, Some oval =>
- lift (fun e => subvar v e oval)
- (df_eval_backprop_deriv σ df1 env2 grad1) =
- lift (fun e => subvar v e oval)
- (df_eval_backprop_deriv σ df2 env1 grad2)
- | _, _, _ => True
- end.
- Proof.
- intros.
- do 3 match_option.
- unfold lift.
- do 2 match_option.
- - assert (vartlookup d0 v <> None); simpler2.
- assert (vartlookup d v <> None); simpler2.
- case_eq (vartlookup d0 v); [intros|tauto].
- case_eq (vartlookup d v); [intros|tauto].
- f_equal. subst v.
- rewrite (split_subvar d0 d2 d1 d4); trivial.
- rewrite (split_subvar d d3 d1 d5); trivial.
- generalize (backprop_indep_env σ df1 s env d0 grad1); intros.
- generalize (backprop_indep_env σ df2 s env d grad2); intros.
- simpl in H6; simpl in H7.
- cut_to H6; trivial; try congruence.
- cut_to H7; trivial; try congruence.
- unfold df_eval_backprop_delta in *.
- rewrite eqq1, H4 in H6.
- rewrite eqq1, H5 in H7.
- unfold lift in H6; simpl in H6.
- unfold lift in H7; simpl in H7.
- rewrite eqq, eqq2 in H6.
- rewrite eqq0, eqq3 in H7.
- inversion H6; inversion H7.
- rewrite H9, H10; lra.
- - assert (df_eval_backprop_deriv σ df2 d grad2 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- - assert (df_eval_backprop_deriv σ df1 d0 grad1 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- Qed.
-
- Lemma list_env_iter_id {A} (env : df_env) (l : list A) :
- list_env_iter (fun (_ : A) (env : df_env) => Some env)
- (Some env) l = Some env.
- Proof.
- now induction l.
- Qed.
-
- Lemma backprop_grad_sum_list_env_iter {m} (σ:df_env)
- (vecdf:Vector (DefinedFunction UnitAnn DTfloat) m) (s: SubVar)
- (grad_env1 grad_env2 grad_env3:df_env)
- (grad1 grad2 : (Vector float m))
- (val1 val2 val3 : float)
- (l : list {m' | m' < m} )
- :
- let v := (s, DTfloat) in
- (forall (i:{m' | m' < m}),
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over (vecdf i) vl) ->
- (forall (i:{m' | m' < m}) (env1 env2 env3 : df_env),
- vartlookup env1 v <> None ->
- vartlookup env2 v <> None ->
- vartlookup env3 v <> None ->
-
- df_eval_backprop_delta σ (vecdf i) v env3
- (grad1 i + grad2 i)%R =
- lift2 dfti_gen_plus
- (df_eval_backprop_delta σ (vecdf i) v env1 (grad1 i))
- (df_eval_backprop_delta σ (vecdf i) v env2 (grad2 i))) ->
-
- vartlookup grad_env1 v = Some val1 ->
- vartlookup grad_env2 v = Some val2 ->
- vartlookup grad_env3 v = Some val3 ->
-
- lift (fun e : df_env => subvar v e val3)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad1 j + grad2 j)%R)
- (Some grad_env3) l) =
- lift2 dfti_gen_plus
- (lift (fun e : df_env => subvar v e val1)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some grad_env1) l))
- (lift (fun e : df_env => subvar v e val2)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some grad_env2) l)).
- Proof.
- intros.
- revert val1 val2 val3 grad_env1 grad_env2 grad_env3 H1 H2 H3.
- induction l.
- - intros.
- simpl; f_equal.
- unfold subvar; simpl.
- rewrite H1,H2,H3.
- lra.
- - intros.
- simpl.
- unfold df_eval_backprop_delta in H0.
- unfold lift, lift2.
- assert (df_eval_backprop_deriv σ (vecdf a) grad_env3 (grad1 a + grad2 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env3 (grad1 a + grad2 a)%R)
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ (vecdf a) grad_env1 (grad1 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env1 (grad1 a)%R)
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ (vecdf a) grad_env2 (grad2 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env2 (grad2 a)%R)
- ; [intros | tauto].
-
- assert (vartlookup d v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 v); congruence.
- assert (vartlookup d0 v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 v); congruence.
- assert (vartlookup d1 v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 v); congruence.
-
- case_eq (vartlookup d v); [intros v3 eq3 |tauto].
- case_eq (vartlookup d0 v); [intros v1 eq1 |tauto].
- case_eq (vartlookup d1 v); [intros v2 eq2 |tauto].
-
- specialize (IHl v1 v2 v3 d0 d1 d eq1 eq2 eq3).
- unfold lift, lift2.
- match_option.
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some d0) l).
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some d1) l).
- + intros.
- rewrite eqq, H13, H14 in IHl.
- unfold lift, lift2 in IHl; simpl in IHl.
- f_equal.
- subst v.
- rewrite (split_subvar d d2 val3 v3); trivial.
- rewrite (split_subvar d0 d4 val1 v1); trivial.
- rewrite (split_subvar d1 d3 val2 v2); trivial.
- inversion IHl.
- rewrite H16.
- specialize (H0 a grad_env1 grad_env2 grad_env3).
- cut_to H0; try congruence.
- rewrite H1,H2,H3 in H0.
- rewrite H5,H7,H9 in H0.
- unfold lift, lift2 in H0.
- inversion H0.
- rewrite H17; lra.
- + intros.
- assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some d1) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- + intros.
- assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some d0) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- + assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad1 j + grad2 j)%R) (Some d) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- Qed.
-
- Lemma backprop_mat_grad_sum_list_env_iter {m n} (σ:df_env)
- (df : DefinedFunction UnitAnn (DTVector n)) (s: SubVar)
- (grad_env1 grad_env2 grad_env3:df_env)
- (grad1 grad2 : (Matrix float m n))
- (val1 val2 val3 : float)
- (l : list {m' | m' < m} )
- :
- let v := (s, DTfloat) in
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- (forall (i:{m' | m' < m}) (env1 env2 env3 : df_env),
- vartlookup env1 v <> None ->
- vartlookup env2 v <> None ->
- vartlookup env3 v <> None ->
-
- df_eval_backprop_delta σ df v env3 (fun j => (grad1 i j + grad2 i j)%R) =
- lift2 dfti_gen_plus
- (df_eval_backprop_delta σ df v env1 (grad1 i))
- (df_eval_backprop_delta σ df v env2 (grad2 i))) ->
-
- vartlookup grad_env1 v = Some val1 ->
- vartlookup grad_env2 v = Some val2 ->
- vartlookup grad_env3 v = Some val3 ->
-
- lift (fun e : df_env => subvar v e val3)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (fun k => (grad1 j k + grad2 j k)%R) )
- (Some grad_env3) l) =
- lift2 dfti_gen_plus
- (lift (fun e : df_env => subvar v e val1)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad1 j)) (Some grad_env1) l))
- (lift (fun e : df_env => subvar v e val2)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad2 j)) (Some grad_env2) l)).
- Proof.
- intros.
- revert val1 val2 val3 grad_env1 grad_env2 grad_env3 H1 H2 H3.
- induction l.
- - intros.
- simpl; f_equal.
- unfold subvar; simpl.
- rewrite H1,H2,H3.
- lra.
- - intros.
- unfold df_eval_backprop_delta in *.
- simpl.
- unfold lift, lift2.
- assert (df_eval_backprop_deriv σ df grad_env3
- (fun k : {n' : nat | n' < n} => (grad1 a k + grad2 a k)%R) <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ df grad_env3
- (fun k : {n' : nat | n' < n} => (grad1 a k + grad2 a k)%R))
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ df grad_env1 (grad1 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ df grad_env1 (grad1 a)%R)
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ df grad_env2 (grad2 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ df grad_env2 (grad2 a)%R)
- ; [intros | tauto].
-
- assert (vartlookup d v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 v); congruence.
- assert (vartlookup d0 v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 v); congruence.
- assert (vartlookup d1 v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 v); congruence.
-
- case_eq (vartlookup d v); [intros v3 eq3 |tauto].
- case_eq (vartlookup d0 v); [intros v1 eq1 |tauto].
- case_eq (vartlookup d1 v); [intros v2 eq2 |tauto].
-
- specialize (IHl v1 v2 v3 d0 d1 d eq1 eq2 eq3).
- match_option.
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad1 j)) (Some d0) l).
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad2 j)) (Some d1) l).
- + intros.
- rewrite eqq, H13, H14 in IHl.
- unfold lift, lift2 in IHl; simpl in IHl.
- f_equal.
- subst v.
- rewrite (split_subvar d d2 val3 v3); trivial.
- rewrite (split_subvar d0 d4 val1 v1); trivial.
- rewrite (split_subvar d1 d3 val2 v2); trivial.
- inversion IHl.
- rewrite H16.
- specialize (H0 a grad_env1 grad_env2 grad_env3).
- cut_to H0; try congruence.
- rewrite H1,H2,H3 in H0.
- rewrite H5,H7,H9 in H0.
- unfold lift, lift2 in H0.
- inversion H0.
- rewrite H17; lra.
- + intros.
- assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad2 j)) (Some d1) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- + intros.
- assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad1 j)) (Some d0) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- + assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env
- (fun k : {n' : nat | n' < n} => (grad1 j k + grad2 j k)%R))
- (Some d) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- Qed.
-
- Lemma matrix_zip_m_n {T} {m n} {i j} {m1 m2 : Matrix T m n} :
- matrix_zip m1 m2 i j = (m1 i j, m2 i j).
- Proof.
- unfold matrix_zip.
- rewrite vmap_nth; simpl.
- now unfold vector_zip.
- Qed.
-
- Lemma backprop_grad_sum {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s: SubVar)
- (grad_env1 grad_env2 grad_env3:df_env)
- (grad1 grad2 : definition_function_types_interp T) :
- let v := (s, DTfloat) in
- vartlookup grad_env1 v <> None ->
- vartlookup grad_env2 v <> None ->
- vartlookup grad_env3 v <> None ->
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- df_eval_backprop_delta σ df (s,DTfloat) grad_env3 (dfti_gen_plus grad1 grad2) =
- lift2 dfti_gen_plus
- (df_eval_backprop_delta σ df (s,DTfloat) grad_env1 grad1)
- (df_eval_backprop_delta σ df (s,DTfloat) grad_env2 grad2).
- Proof.
- unfold df_eval_backprop_delta.
- revert grad_env1 grad_env2 grad_env3.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
-(*
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case
-*)
- ; simpl; intros.
- - Case "Number"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1.
- f_equal; lra.
- - Case "Constant"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1.
- f_equal; lra.
- - Case "DVector"%string.
- unfold two_vector_env_iter_alt in *.
- rewrite vforall_forall in H3.
- revert grad_env1 grad_env2 grad_env3 H0 H1 H2.
- induction (bounded_seq0 n).
- + intros.
- simpl.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1.
- f_equal; lra.
- + intros.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift; simpl in *.
- assert (df_eval_backprop_deriv σ (x a) grad_env3 (grad1 a + grad2 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- assert (df_eval_backprop_deriv σ (x a) grad_env1 (grad1 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- assert (df_eval_backprop_deriv σ (x a) grad_env2 (grad2 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ (x a) grad_env3 (grad1 a + grad2 a)%R)
- ; [intros|tauto].
- case_eq (df_eval_backprop_deriv σ (x a) grad_env1 (grad1 a)%R)
- ; [intros|tauto].
- case_eq (df_eval_backprop_deriv σ (x a) grad_env2 (grad2 a)%R)
- ; [intros|tauto].
- assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad1 i + grad2 i)%R)
- (Some d2) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply (H3 a0).
- match_option; [|tauto].
- assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad1 i)) (Some d3)
- l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply (H3 a0).
- match_option; [|tauto].
- assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad2 i)) (Some d4)
- l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply (H3 a0).
- match_option; [|tauto].
- unfold lift2; simpl.
-
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 (s, DTfloat) H2).
- assert (vartlookup d3 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H8 (s, DTfloat) H0).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 (s, DTfloat) H1).
-
- case_eq (vartlookup d2 (s, DTfloat)); [intros|tauto].
- case_eq (vartlookup d3 (s, DTfloat)); [intros|tauto].
- case_eq (vartlookup d4 (s, DTfloat)); [intros|tauto].
-
- rewrite (split_subvar d2 d5 d d8); trivial.
- rewrite (split_subvar d3 d6 d0 d9); trivial.
- rewrite (split_subvar d4 d7 d1 d10); trivial.
-
- f_equal.
-
- specialize (IHl d3 d4 d2 H14 H15 H13).
- rewrite H16, H17, H18 in IHl.
- rewrite eqq2, eqq3, eqq4 in IHl.
- unfold lift, lift2 in IHl.
- inversion IHl.
- rewrite H20.
-
- specialize (H a (grad1 a) (grad2 a) grad_env1 grad_env2 grad_env3 H0 H1 H2).
- specialize (H (H3 a)).
-
- rewrite eqq,eqq0,eqq1 in H.
- rewrite H7, H8, H9 in H.
- unfold lift, lift2 in H.
- inversion H.
- rewrite H21.
- lra.
- - Case "DMatrix"%string.
- unfold two_matrix_env_iter_alt in *.
- rewrite vforall_forall in H3.
- revert grad_env1 grad_env2 grad_env3 H0 H1 H2.
- induction (bounded_seq0 n); induction (bounded_seq0 m).
- + intros.
- simpl.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1.
- f_equal; lra.
- + intros.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- simpl.
- unfold lift, lift2; simpl; f_equal.
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1; simpl; lra.
- + intros.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- rewrite list_env_iter_id.
- rewrite list_env_iter_id.
- rewrite list_env_iter_id.
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1; simpl; f_equal; lra.
- + intros.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- simpl.
- assert (df_eval_backprop_deriv σ (x a a0) grad_env3 (grad1 a a0 + grad2 a a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- assert (df_eval_backprop_deriv σ (x a a0) grad_env1 (grad1 a a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- assert (df_eval_backprop_deriv σ (x a a0) grad_env2 (grad2 a a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- case_eq (df_eval_backprop_deriv σ (x a a0) grad_env3 (grad1 a a0 + grad2 a a0)%R)
- ; [intros|tauto].
- case_eq (df_eval_backprop_deriv σ (x a a0) grad_env1 (grad1 a a0)%R)
- ; [intros|tauto].
- case_eq (df_eval_backprop_deriv σ (x a a0) grad_env2 (grad2 a a0)%R)
- ; [intros|tauto].
- assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R)
- (Some d2) l0 <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R)
- (Some d2) l0 ); [intros | tauto].
- assert (list_env_iter
- (fun (i : {n' : nat | n' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a i) env (grad1 a i)) (Some d3)
- l0 <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad1 a j)%R)
- (Some d3) l0 ); [intros | tauto].
- assert (list_env_iter
- (fun (i : {n' : nat | n' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a i) env (grad2 a i)) (Some d4)
- l0 <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad2 a j)%R)
- (Some d4) l0 ); [intros | tauto].
- assert
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad1 i j + grad2 i j)%R)
- (df_eval_backprop_deriv σ (x i a0) env (grad1 i a0 + grad2 i a0)%R) l0)
- (Some d5) l <> None).
- apply list_env_iter_total_fun; intros.
- assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0 + grad2 a1 a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0 + grad2 a1 a0)%R); [intros|tauto].
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- unfold lift; simpl.
- match_option; [|tauto].
-
- assert
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad1 i j)%R)
- (df_eval_backprop_deriv σ (x i a0) env (grad1 i a0)%R) l0)
- (Some d6) l <> None).
- apply list_env_iter_total_fun; intros.
- assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0)%R); [intros|tauto].
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- match_option; [|tauto].
-
- assert
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad2 i j)%R)
- (df_eval_backprop_deriv σ (x i a0) env (grad2 i a0)%R) l0)
- (Some d7) l <> None).
- apply list_env_iter_total_fun; intros.
- assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad2 a1 a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad2 a1 a0)%R); [intros|tauto].
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- match_option; [|tauto].
-
- unfold lift2; simpl; f_equal.
-
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 (s, DTfloat) H2).
- assert (vartlookup d3 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H8 (s, DTfloat) H0).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 (s, DTfloat) H1).
-
- case_eq (vartlookup d2 (s, DTfloat)); [intros|tauto].
- case_eq (vartlookup d3 (s, DTfloat)); [intros|tauto].
- case_eq (vartlookup d4 (s, DTfloat)); [intros|tauto].
-
- assert (Hc := H).
- assert (H3c := H3).
-
- specialize (H a a0 (grad1 a a0) (grad2 a a0) grad_env1 grad_env2 grad_env3
- H0 H1 H2).
-
- specialize (H3 a).
- rewrite vforall_forall in H3.
- specialize (H (H3 a0)).
- rewrite eqq, eqq0, eqq1 in H; simpl in H.
- rewrite H7, H8, H9 in H; unfold lift, lift2 in H; simpl in H.
-
- assert (vartlookup d5 (s, DTfloat) <> None).
- apply (vartlookup_list_env_iter
- s (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R)
- l0 d2 d5); trivial; intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H25 (s, DTfloat) H26).
- case_eq (vartlookup d5 (s, DTfloat)); [intros|tauto].
- assert (vartlookup d6 (s, DTfloat) <> None).
- apply (vartlookup_list_env_iter
- s (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad1 a j)%R)
- l0 d3 d6); trivial; intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H27 (s, DTfloat) H28).
- case_eq (vartlookup d6 (s, DTfloat)); [intros|tauto].
- assert (vartlookup d7 (s, DTfloat) <> None).
- apply (vartlookup_list_env_iter
- s (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad2 a j)%R)
- l0 d4 d7); trivial; intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H29 (s, DTfloat) H30).
- case_eq (vartlookup d7 (s, DTfloat)); [intros|tauto].
-
- specialize (IHl d6 d7 d5 H27 H29 H25).
- rewrite H28, H30, H26 in IHl; simpl in IHl.
-
- rewrite eqq2,eqq3,eqq4 in IHl.
- unfold lift, lift2 in IHl; simpl in IHl.
-
- rewrite (split_subvar d5 d8 d d14); trivial.
- rewrite (split_subvar d6 d9 d0 d15); trivial.
- rewrite (split_subvar d7 d10 d1 d16); trivial.
-
- rewrite (split_subvar d2 d5 d d11); trivial.
- rewrite (split_subvar d3 d6 d0 d12); trivial.
- rewrite (split_subvar d4 d7 d1 d13); trivial.
-
- inversion H.
- inversion IHl.
- rewrite H32, H33.
-
- generalize (backprop_grad_sum_list_env_iter
- σ (x a) s d3 d4 d2 (grad1 a) (grad2 a)
- d12 d13 d11 l0); intros.
- specialize (H31 H3).
- cut_to H31.
- * rewrite H11,H13,H15 in H31.
- unfold lift, lift2 in H31.
- inversion H31.
- rewrite H35; lra.
- * intros.
- unfold df_eval_backprop_delta.
- specialize (Hc a i (grad1 a i) (grad2 a i) env1 env2 env3).
- specialize (Hc H34 H35 H36).
- specialize (Hc (H3 i)).
- apply Hc.
- * trivial.
- * trivial.
- * trivial.
- - Case "Var"%string.
- match_option; [|tauto].
- case_eq (vartlookup grad_env1 (s, DTfloat)); [intros| tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros| tauto].
- destruct (v == (s,DTfloat)).
- + invcs e.
- rewrite eqq, H3, H4; simpl.
- f_equal.
- rewrite subvar_addvar_scalar_eq; trivial.
- rewrite subvar_addvar_scalar_eq; trivial.
- rewrite subvar_addvar_scalar_eq; trivial.
- + assert (v<> (s, DTfloat)) by congruence.
- match_option.
- * match_option.
- -- match_option; unfold lift, lift2; simpl; f_equal.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- rewrite subvar_addvar_scalar_neq; trivial.
- rewrite subvar_addvar_scalar_neq; trivial.
- lra.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl; rewrite H4; lra.
- -- unfold lift, lift2; simpl; f_equal.
- rewrite subvar_addvar_scalar_neq; trivial.
- case_eq (vartlookup grad_env2 v); intros.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- f_equal; unfold subvar; simpl.
- rewrite H3; lra.
- ++ unfold subvar; simpl.
- rewrite H3, H4.
- f_equal; lra.
- * match_option.
- -- match_option; unfold lift, lift2; simpl; f_equal.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl.
- rewrite eqq; lra.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl.
- rewrite eqq, H4; lra.
- -- match_option; unfold lift, lift2; simpl; f_equal.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl.
- rewrite eqq, H3; lra.
- ++ unfold subvar; simpl.
- rewrite eqq, H3, H4; lra.
- - Case "Plus"%string.
- destruct H2.
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 grad1 grad2 env1 env3 env).
- simpl in IHdf1.
- rewrite eqq5, eqq6, eqq7 in IHdf1.
- rewrite eqq, eqq1, eqq3 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- rewrite eqq0, eqq2, eqq4 in IHdf2.
- unfold lift, lift2 in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 val val2); trivial.
- rewrite (split_subvar env1 env2 val0 val3); trivial.
- rewrite (split_subvar env3 env4 val1 val4); trivial.
- rewrite H5, H6; lra.
- - Case "Minus"%string.
- destruct H2.
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 (-grad1)%R (-grad2)%R env1 env3 env).
- simpl in IHdf1.
- rewrite eqq5, eqq6, eqq7 in IHdf1.
- rewrite eqq, eqq1, eqq3 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- replace (- grad1 + - grad2)%R with (- (grad1 + grad2))%R in IHdf2 by lra.
- rewrite eqq0, eqq2, eqq4 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- rewrite (split_subvar env env0 val val2); trivial.
- rewrite (split_subvar env1 env2 val0 val3); trivial.
- rewrite (split_subvar env3 env4 val1 val4); trivial.
- f_equal; rewrite H5, H6; lra.
- - Case "Times"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1 (d1 * grad1)%R (d1 * grad2)%R grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 (d0 * grad1)%R (d0 * grad2)%R env1 env3 env).
- rewrite eqq8, eqq9, eqq in IHdf1.
- simpl in IHdf1.
- replace (d1 *grad1 + d1*grad2)%R with (d1 * (grad1 + grad2))%R in IHdf1 by lra.
- rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- replace (d0 *grad1 + d0 *grad2)%R with (d0 *(grad1 + grad2))%R in IHdf2 by lra.
- rewrite eqq3,eqq5,eqq7 in IHdf2.
- unfold lift, lift2 in IHdf2; inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 d val1); trivial.
- rewrite (split_subvar env1 env2 val val2); trivial.
- rewrite (split_subvar env3 env4 val0 val3); trivial.
- rewrite H7, H8; lra.
- - Case "Divide"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1 (grad1/d1)%R (grad2/d1)%R grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 (-d0/(d1*d1) * grad1)%R (-d0/(d1*d1) * grad2)%R env1 env3 env).
- rewrite eqq, eqq8, eqq9 in IHdf1.
- simpl in IHdf1.
- replace (grad1/d1 + grad2/d1)%R with ((grad1 + grad2)/d1)%R in IHdf1 by lra.
- rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- replace (-d0/(d1*d1) *grad1 + -d0/(d1*d1) *grad2)%R with (-d0/(d1*d1) *(grad1 + grad2))%R in IHdf2 by lra.
- rewrite eqq3, eqq5, eqq7 in IHdf2.
- unfold lift, lift2 in IHdf2; inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 d val1); trivial.
- rewrite (split_subvar env1 env2 val val2); trivial.
- rewrite (split_subvar env3 env4 val0 val3); trivial.
- rewrite H7, H8; lra.
- - Case "Square"%string.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df); intros.
- specialize (H3 H2).
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (2 * d0 *grad1)%R (2 * d0 *grad2)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq1, eqq2 in IHdf.
- simpl in IHdf.
- replace (2 * d0 * grad1 + 2 * d0 * grad2)%R with (2 * d0 * (grad1 + grad2))%R in IHdf by lra.
- apply IHdf.
- - Case "Exp"%string.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df); intros.
- specialize (H3 H2).
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (grad1 * exp d0)%R (grad2*exp d0)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq1, eqq2 in IHdf.
- simpl in IHdf.
- replace (grad1 * exp d0 + grad2 * exp d0)%R with ((grad1 + grad2) * exp d0)%R in IHdf by lra.
- apply IHdf.
- - Case "Log"%string.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df); intros.
- specialize (H3 H2).
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (grad1/d0)%R (grad2/d0)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq1, eqq2 in IHdf.
- simpl in IHdf.
- replace (grad1/d0 + grad2/d0)%R with ((grad1 + grad2) / d0)%R in IHdf by lra.
- apply IHdf.
- - Case "Abs"%string.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df); intros.
- specialize (H3 H2).
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (grad1*sign d0)%R (grad2*sign d0)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq1, eqq2 in IHdf.
- simpl in IHdf.
- replace (grad1*sign d0 + grad2*sign d0)%R with ((grad1 + grad2) * sign d0)%R in IHdf by lra.
- apply IHdf.
- - Case "Sign"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (0)%R (0)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- replace (0 + 0)%R with 0%R in IHdf by lra.
- apply IHdf.
- - Case "PSign"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (0)%R (0)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- replace (0 + 0)%R with 0%R in IHdf by lra.
- apply IHdf.
- - Case "Max"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- destruct (Rle_dec d0 d1).
- + match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf2 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H3).
- rewrite eqq, eqq2, eqq3 in IHdf2; simpl in IHdf2.
- apply IHdf2.
- + match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- rewrite eqq, eqq2, eqq3 in IHdf1; simpl in IHdf1.
- apply IHdf1.
- - Case "VectorDot"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1
- (vmap (fun rv : R => (rv * grad1)%R) d1)
- (vmap (fun rv : R => (rv * grad2)%R) d1)
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2
- (vmap (fun lv : R => (lv * grad1)%R) d0)
- (vmap (fun lv : R => (lv * grad2)%R) d0)
- env1 env3 env).
- rewrite eqq, eqq8, eqq9 in IHdf1.
- simpl in IHdf1.
- replace
- (fun i : {n' : nat | n' < n} =>
- (vmap (fun rv : R => rv * grad1) d1 i + vmap (fun rv : R => rv * grad2) d1 i)%R)
- with
- (vmap (fun rv : R => (rv * (grad1 + grad2))%R) d1)
- in IHdf1.
- + rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; inversion IHdf1.
- simpl in IHdf2.
- cut_to IHdf2; simpler2.
- replace (fun i : {n' : nat | n' < n} =>
- (vmap (fun lv : R => lv * grad1) d0 i +
- vmap (fun lv : R => lv * grad2) d0 i)%R) with
- (vmap (fun lv : R => (lv * (grad1 + grad2))%R) d0) in IHdf2.
- * rewrite eqq3, eqq5, eqq7 in IHdf2.
- unfold lift, lift2 in IHdf2; inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 d val1); trivial.
- rewrite (split_subvar env1 env2 val val2); trivial.
- rewrite (split_subvar env3 env4 val0 val3); trivial.
- rewrite H7, H8; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite vmap_nth.
- lra.
- + apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite vmap_nth.
- lra.
- - Case "VectorSum"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (ConstVector n grad1) (ConstVector n grad2) grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- unfold ConstVector in IHdf.
- unfold ConstVector.
- apply IHdf.
- - Case "MatrixSum"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (ConstMatrix m n grad1) (ConstMatrix m n grad2) grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- unfold ConstMatrix in IHdf.
- unfold ConstMatrix.
- apply IHdf.
- - Case "VectorElem"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (fun k : {n' : nat | n' < n} =>
- if equiv_dec (` k) (` i) then grad1 else 0%R)
- (fun k : {n' : nat | n' < n} =>
- if equiv_dec (` k) (` i) then grad2 else 0%R)
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- replace (fun k : {n' : nat | n' < n} =>
- if equiv_dec (` k) (` i) then (grad1 + grad2)%R else 0%R) with
- (fun i0 : {n' : nat | n' < n} =>
- ((if equiv_dec (` i0) (` i) then grad1 else 0) +
- (if equiv_dec (` i0) (` i) then grad2 else 0))%R).
- apply IHdf.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)); lra.
- - Case "MatrixElem"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf
- (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i)
- then if equiv_dec (` k2) (` j) then grad1 else 0%R
- else 0%R)
- (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i)
- then if equiv_dec (` k2) (` j) then grad2 else 0%R
- else 0%R)
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- replace (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i)
- then if equiv_dec (` k2) (` j) then (grad1 + grad2)%R else 0%R
- else 0%R) with
- (fun (i0 : {n' : nat | n' < m}) (j0 : {m' : nat | m' < n}) =>
- ((if equiv_dec (` i0) (` i)
- then if equiv_dec (` j0) (` j) then grad1 else 0
- else 0) +
- (if equiv_dec (` i0) (` i)
- then if equiv_dec (` j0) (` j) then grad2 else 0
- else 0))%R).
- apply IHdf.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)).
- + apply FunctionalExtensionality.functional_extensionality; intros.
- destruct (equiv_dec (` x0) (` j)); lra.
- + apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "MatrixVectorMult"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1
- (fun (i : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (grad1 i * d1 j)%R)
- (fun (i : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (grad2 i * d1 j)%R)
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- rewrite eqq, eqq8, eqq9 in IHdf1.
- specialize (IHdf2
- (matrix_vector_mult
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d0 j i) grad1)
- (matrix_vector_mult
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d0 j i) grad2)
- env1 env3 env).
- simpl in IHdf1.
- replace
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (grad1 i * d1 j + grad2 i * d1 j)%R) with
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- ((grad1 i + grad2 i) * d1 j)%R) in IHdf1.
- + rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- replace
- (fun i : {n' : nat | n' < n} =>
- (matrix_vector_mult
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d0 j i0) grad1 i +
- matrix_vector_mult
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d0 j i0) grad2 i)%R) with
- (matrix_vector_mult
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => d0 j i)
- (fun i : {n' : nat | n' < m} => (grad1 i + grad2 i)%R)) in IHdf2.
- * rewrite eqq3, eqq5, eqq7 in IHdf2.
- unfold lift, lift2 in IHdf2; inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 d val1); trivial.
- rewrite (split_subvar env1 env2 val val2); trivial.
- rewrite (split_subvar env3 env4 val0 val3); trivial.
- rewrite H7, H8; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- unfold matrix_vector_mult.
- rewrite vsum_plus; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- + apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "MatrixVectorAdd"%string.
- destruct H2.
- simpl; intros.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- simpl in IHdf1.
- rewrite eqq,eqq0,eqq1,eqq2,eqq3,eqq4 in IHdf1.
- unfold lift, lift2 in IHdf1.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H); intro.
- case_eq (vartlookup env0 (s, DTfloat)); [intros|tauto].
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat) H0); intro.
- case_eq (vartlookup env1 (s, DTfloat)); [intros|tauto].
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq (s, DTfloat) H1); intro.
- case_eq (vartlookup env (s, DTfloat)); [intros|tauto].
- generalize (backprop_mat_grad_sum_list_env_iter
- σ df2 s env0 env1 env (transpose grad1) (transpose grad2)
- d d0 d1 (bounded_seq0 n)); intros.
- simpl in H10.
- specialize (H10 H3).
- cut_to H10; trivial.
- + do 3 match_option.
- * rewrite eqq6, eqq7 in H10.
- replace (fun (j : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env
- (fun k : {n' : nat | n' < m} => ((@transpose R m n grad1 j k) +
- (@transpose R m n grad2 j k))%R))
- with
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env
- (transpose
- (fun (i0 : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (grad1 i0 j + grad2 i0 j)%R) i)) in H10.
- -- rewrite eqq5 in H10.
- unfold lift, lift2 in H10.
- unfold lift, lift2; f_equal.
- inversion IHdf1; inversion H10.
- rewrite (split_subvar env d2 val d1); trivial.
- rewrite (split_subvar env0 d3 val0 d); trivial.
- rewrite (split_subvar env1 d4 val1 d0); trivial.
- rewrite H12, H13; lra.
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- f_equal.
- * assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad2 i))
- (Some env1) (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- * assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad1 i))
- (Some env0) (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- * assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad1 i))
- (Some env0) (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- * assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env
- (transpose
- (fun (i0 : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (grad1 i0 j + grad2 i0 j)%R) i)) (Some env) (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- + intros.
- apply IHdf2; trivial.
- - Case "MatrixMult"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize
- (IHdf1
- (matrix_mult grad1
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i))
- (matrix_mult grad2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i))
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- rewrite eqq, eqq8, eqq9 in IHdf1.
- specialize (IHdf2
- (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d0 j i)
- grad1)
- (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d0 j i)
- grad2)
- env1 env3 env).
- simpl in IHdf1.
- replace
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < p}) =>
- (matrix_mult grad1
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) =>
- d1 j0 i0) i j +
- matrix_mult grad2
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) =>
- d1 j0 i0) i j)%R)
- with
- (matrix_mult
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (grad1 i j + grad2 i j)%R)
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i)) in IHdf1.
- + rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- rewrite eqq5, eqq7 in IHdf2.
- replace
- (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < n}) =>
- (matrix_mult
- (fun (i0 : {n' : nat | (n' < p)%nat}) (j0 : {m' : nat | (m' < m)%nat}) =>
- d0 j0 i0) grad1 i j +
- matrix_mult
- (fun (i0 : {n' : nat | (n' < p)%nat}) (j0 : {m' : nat | (m' < m)%nat}) =>
- d0 j0 i0) grad2 i j)%R) with
- (matrix_mult (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < m}) => d0 j i)
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (grad1 i j + grad2 i j)%R))
- in IHdf2.
- * rewrite eqq3 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 d val1); trivial.
- rewrite (split_subvar env1 env2 val val2); trivial.
- rewrite (split_subvar env3 env4 val0 val3); trivial.
- rewrite H7, H8; lra.
- * unfold matrix_mult.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vsum_plus; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- + unfold matrix_mult.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vsum_plus; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- - Case "VectorPlus"%string.
- destruct H2.
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 grad1 grad2 env1 env3 env).
- simpler2.
- simpl in IHdf1.
- rewrite eqq1, eqq3, eqq in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- cut_to IHdf2; try congruence.
- simpl in IHdf2.
- rewrite eqq0, eqq2, eqq4 in IHdf2.
- unfold lift, lift2 in IHdf2; inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 val val5); trivial.
- rewrite (split_subvar env1 env2 val0 val6); trivial.
- rewrite (split_subvar env3 env4 val1 val7); trivial.
- rewrite eqq5 in eqq8.
- rewrite eqq6 in eqq9.
- rewrite eqq7 in eqq10.
- invcs eqq8; invcs eqq9; invcs eqq10.
- rewrite H5, H6; lra.
- - Case "VectorMinus"%string.
- destruct H2.
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad1 i)%R)
- (fun i : {n' : nat | n' < n} => (- grad2 i)%R)
- env1 env3 env).
- simpler2.
- simpl in IHdf1.
- rewrite eqq, eqq1, eqq3 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- cut_to IHdf2; try congruence.
- simpl in IHdf2.
- replace (fun i : {n' : nat | n' < n} => (- grad1 i + - grad2 i)%R) with
- (fun i : {n' : nat | n' < n} => (- (grad1 i + grad2 i))%R) in IHdf2.
- rewrite eqq0, eqq2, eqq4 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 val val5); trivial.
- rewrite (split_subvar env1 env2 val0 val6); trivial.
- rewrite (split_subvar env3 env4 val1 val7); trivial.
- rewrite eqq5 in eqq8; rewrite eqq6 in eqq9; rewrite eqq7 in eqq10.
- invcs eqq8; invcs eqq9; invcs eqq10.
- rewrite H5, H6; lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "MatrixPlus"%string.
- destruct H2.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv
- σ df1 grad_env3
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 (grad1)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 (grad2)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 grad1 grad2 d2 d4 d0).
- rewrite eqq, eqq1, eqq3 in IHdf1.
- simpl in IHdf1.
- rewrite eqq0, eqq2, eqq4 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H0).
- assert (vartlookup d0 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H1).
- specialize (IHdf2 H7 H9 H10 H3).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- simpl in IHdf2.
- unfold lift, lift2; simpl.
- assert (df_eval_backprop_deriv
- σ df2 d0
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R) <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d2 (grad1)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d2 (grad1)%R); [intros|tauto].
- assert (df_eval_backprop_deriv σ df2 d4 (grad2)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d4 (grad2)%R); [intros|tauto].
- simpl in IHdf2.
- rewrite eqq8, H13, H15 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar d0 d8 d d5); trivial.
- rewrite (split_subvar d2 d9 d1 d6); trivial.
- rewrite (split_subvar d4 d10 d3 d7); trivial.
- rewrite H8, H17; lra.
- - Case "MatrixMinus"%string.
- destruct H2.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv
- σ df1 grad_env3
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R) <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 (grad1)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 (grad2)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (- grad1 i j)%R)
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (- grad2 i j)%R)
- d2 d4 d0).
- rewrite eqq, eqq1, eqq3 in IHdf1.
- simpl in IHdf1.
- rewrite eqq0, eqq2, eqq4 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H0).
- assert (vartlookup d0 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H1).
- specialize (IHdf2 H7 H9 H10 H3).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift, lift2; simpl.
- assert (df_eval_backprop_deriv
- σ df2 d0
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (- (grad1 i j + grad2 i j))%R) <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv
- σ df2 d2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad1 i j)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad1 i j)%R))
- ; [intros |tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad2 i j)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d4
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad2 i j)%R))
- ; [intros | tauto].
- simpl in IHdf2.
- replace
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (- grad1 i j + - grad2 i j)%R) with
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (- (grad1 i j + grad2 i j))%R) in IHdf2.
- rewrite eqq8, H13, H15 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar d0 d8 d d5); trivial.
- rewrite (split_subvar d2 d9 d1 d6); trivial.
- rewrite (split_subvar d4 d10 d3 d7); trivial.
- rewrite H8, H17; lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "VectorScalMult"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env3
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * (grad1 j + grad2 j))%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad1 j)%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad2 j)%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- specialize (IHdf1
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad1 j)%R))
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad2 j)%R))
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2
- (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R)
- (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R)
- d4 d6 d2).
- rewrite eqq, eqq3, eqq5 in IHdf1.
- simpl in IHdf1.
- replace
- (@vsum floatish_R n (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad1 j) +
- @vsum floatish_R n (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad2 j))%R with
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * (grad1 j + grad2 j))%R)) in IHdf1.
- + rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H1).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H).
- assert (vartlookup d6 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq6 (s, DTfloat) H0).
- specialize (IHdf2 H11 H12 H9 H3).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift, lift2; simpl.
- assert (df_eval_backprop_deriv
- σ df2 d2
- (fun j : {n' : nat | n' < n} => (d0 * (grad1 j + grad2 j))%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d4
- (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R))
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ df2 d6
- (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R)
- <> None ).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d6
- (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R))
- ; [intros | tauto].
- simpl in IHdf2.
- replace
- (fun i : {n' : nat | n' < n} => (d0 * grad1 i + d0 * grad2 i)%R) with
- (fun j : {n' : nat | n' < n} => (d0 * (grad1 j + grad2 j))%R) in IHdf2.
- * rewrite eqq10, H15, H17 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar d2 d10 d d7); trivial.
- rewrite (split_subvar d4 d11 d3 d8); trivial.
- rewrite (split_subvar d6 d12 d5 d9); trivial.
- rewrite H10, H19; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + simpl.
- rewrite vsum_plus.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "MatrixScalMult"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv
- σ df1 grad_env3
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * (grad1 i j + grad2 i j))%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * grad1 i j)%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * grad2 i j)%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- specialize (IHdf1
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * grad1 i j)%R))
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * grad2 i j)%R))
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (grad1 i j * d0)%R)
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (grad2 i j * d0)%R)
- d4 d6 d2).
- rewrite eqq, eqq3, eqq5 in IHdf1.
- simpl in IHdf1.
- replace
- (@msum floatish_R n m
- (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d1 i j * grad1 i j) +
- @msum floatish_R n m
- (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d1 i j * grad2 i j))%R with
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * (grad1 i j + grad2 i j))%R)) in IHdf1.
- + rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H1).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H).
- assert (vartlookup d6 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq6 (s, DTfloat) H0).
- specialize (IHdf2 H11 H12 H9 H3).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift, lift2; simpl.
- assert (df_eval_backprop_deriv
- σ df2 d2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- ((grad1 i j + grad2 i j) * d0)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad1 i j * d0)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d4
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad1 i j * d0)%R))
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ df2 d6
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad2 i j * d0)%R)
- <> None ).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d6
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad2 i j * d0)%R))
- ; [intros | tauto].
- simpl in IHdf2.
- replace
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (grad1 i j * d0 + grad2 i j * d0)%R) with
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- ((grad1 i j + grad2 i j) * d0)%R) in IHdf2.
- * rewrite eqq10, H15, H17 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar d2 d10 d d7); trivial.
- rewrite (split_subvar d4 d11 d3 d8); trivial.
- rewrite (split_subvar d6 d12 d5 d9); trivial.
- rewrite H10, H19; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + unfold msum.
- rewrite vsum_plus.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite vsum_plus.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "VectorApply"%string.
- destruct H2.
- simpler2.
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H4 H3).
- match_option.
- match_option.
- + match_option.
- * match_option.
- -- unfold lift.
- repeat simpl_closed_backprop.
- unfold lift2.
- specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros.
- specialize (IHdf2 v1 v2 grad_env1 grad_env2 grad_env3).
- cut_to IHdf2; try congruence.
- rewrite eqq, eqq0, eqq1, eqq7, eqq8 in IHdf2; simpl in IHdf2.
- replace (fun i : {n' : nat | n' < n} => (v1 i + v2 i)%R) with v0 in IHdf2.
- rewrite eqq6 in IHdf2.
- unfold lift, lift2 in IHdf2.
- apply IHdf2.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H5 x);rewrite vmap_nth in H5; simpl in H5.
- specialize (H6 x);rewrite vmap_nth in H6; simpl in H6.
- specialize (H7 x);rewrite vmap_nth in H7; simpl in H7.
- match_option_in H5; invcs H5.
- match_option_in H6; invcs H6.
- match_option_in H7; invcs H7.
- rewrite eqq9 in eqq10; invcs eqq10.
- rewrite eqq9 in eqq11; invcs eqq11.
- lra.
- -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- * specialize (apply vectoro_to_ovector_exists_None eqq4); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- + specialize (apply vectoro_to_ovector_exists_None eqq3); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- - Case "MatrixApply"%string.
- destruct H2.
- simpler2.
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H4 H3).
- match_option.
- match_option.
- + match_option.
- * match_option.
- -- unfold lift.
- repeat simpl_closed_backprop.
- unfold lift2.
- unfold matrixo_to_omatrix in *.
- specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros.
- specialize (IHdf2 m1 m2 grad_env1 grad_env2 grad_env3).
- cut_to IHdf2; try congruence.
- rewrite eqq, eqq0, eqq1, eqq7, eqq8 in IHdf2; simpl in IHdf2.
- replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (m1 i j + m2 i j)%R) with m0 in IHdf2.
- rewrite eqq6 in IHdf2.
- unfold lift, lift2 in IHdf2.
- apply IHdf2.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H5 x); specialize (H6 x); specialize (H7 x).
- unfold mmap in H5; unfold mmap in H6; unfold mmap in H7.
- specialize (apply vectoro_to_ovector_forall_some_f H5);intros.
- specialize (apply vectoro_to_ovector_forall_some_f H6);intros.
- specialize (apply vectoro_to_ovector_forall_some_f H7);intros.
- specialize (H8 x0); do 2 rewrite vmap_nth in H8.
- specialize (H9 x0); do 2 rewrite vmap_nth in H9.
- specialize (H10 x0); do 2 rewrite vmap_nth in H10.
- rewrite matrix_zip_m_n in H8.
- rewrite matrix_zip_m_n in H9.
- rewrite matrix_zip_m_n in H10.
- match_option_in H8.
- rewrite eqq9 in H9; rewrite eqq9 in H10.
- invcs H8; invcs H9; invcs H10.
- lra.
- -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros; destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5.
- unfold mmap in e0.
- do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- * specialize (apply vectoro_to_ovector_exists_None eqq4); intros; destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5.
- unfold mmap in e0.
- do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- + specialize (apply vectoro_to_ovector_exists_None eqq3); intros; destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5.
- unfold mmap in e0.
- do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- - Case "VLossfun"%string.
- destruct H2.
- simpler2.
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H4 H3).
- match_option.
- match_option.
- + match_option.
- * match_option.
- -- unfold lift.
- repeat simpl_closed_backprop.
- unfold lift2.
- specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros.
- specialize (IHdf2 v0 v3 grad_env1 grad_env2 grad_env3).
- cut_to IHdf2; try congruence.
- rewrite eqq, eqq0, eqq1 in IHdf2.
- rewrite eqq7, eqq8 in IHdf2.
- simpl in IHdf2.
- replace (fun i : {n' : nat | n' < n} => (v0 i + v3 i)%R) with v in IHdf2.
- rewrite eqq6 in IHdf2.
- unfold lift, lift2 in IHdf2.
- apply IHdf2.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H5 x);rewrite vmap_nth in H5; simpl in H5.
- specialize (H6 x);rewrite vmap_nth in H6; simpl in H6.
- specialize (H7 x);rewrite vmap_nth in H7; simpl in H7.
- match_option_in H5.
- rewrite eqq9 in H6; rewrite eqq9 in H7.
- invcs H5; invcs H6; invcs H7.
- lra.
- -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert ( df_eval [mk_env_entry (v1, DTfloat) (d x);
- mk_env_entry (v2, DTfloat) (r x)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- * specialize (apply vectoro_to_ovector_exists_None eqq4); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x);
- mk_env_entry (v2, DTfloat) (r x)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- + specialize (apply vectoro_to_ovector_exists_None eqq3); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x); mk_env_entry (v2, DTfloat) (r x)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- - Case "MLossfun"%string.
- destruct H2.
- simpler2.
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H4 H3).
- match_option.
- match_option.
- + match_option.
- * match_option.
- -- unfold lift.
- repeat simpl_closed_backprop.
- unfold lift2.
- specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros.
- specialize (IHdf2 m1 m2 grad_env1 grad_env2 grad_env3).
- cut_to IHdf2; try congruence.
- rewrite eqq, eqq0, eqq1 in IHdf2.
- rewrite eqq7, eqq8 in IHdf2.
- simpl in IHdf2.
- replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (m1 i j + m2 i j)%R) with m0 in IHdf2.
- rewrite eqq6 in IHdf2.
- unfold lift, lift2 in IHdf2.
- apply IHdf2.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H5 x); simpl in H5.
- specialize (H6 x); simpl in H6.
- specialize (H7 x); simpl in H7.
- specialize (apply vectoro_to_ovector_forall_some_f H5);intros.
- specialize (apply vectoro_to_ovector_forall_some_f H6);intros.
- specialize (apply vectoro_to_ovector_forall_some_f H7);intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H8 x0); unfold mmap in H8;do 2 rewrite vmap_nth in H8; simpl in H8.
- specialize (H9 x0); unfold mmap in H9;do 2 rewrite vmap_nth in H9; simpl in H9.
- specialize (H10 x0); unfold mmap in H10;do 2 rewrite vmap_nth in H10; simpl in H10.
- rewrite matrix_zip_m_n in H8.
- rewrite matrix_zip_m_n in H9.
- rewrite matrix_zip_m_n in H10.
- match_option_in H8.
- rewrite eqq9 in H9; rewrite eqq9 in H10.
- invcs H8; invcs H9; invcs H10.
- lra.
- -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros.
- destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H5.
- unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert ( df_eval [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- * specialize (apply vectoro_to_ovector_exists_None eqq4); intros.
- destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H5.
- unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- + specialize (apply vectoro_to_ovector_exists_None eqq3); intros.
- destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H5.
- unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- Qed.
-
- Lemma vectoro_to_ovector_eNone_None {A n} {vo:Vector (option A) n} :
- {i | vo i = None} ->
- vectoro_to_ovector vo = None.
- Proof.
- intros.
- destruct H.
- now apply (vectoro_to_ovector_None_None x).
- Qed.
-
- Definition ConstSplitVector {T} (middle theend:nat) (part1 part2:T) : (Vector T theend) :=
- fun (i: {n':nat | n' < theend}%nat) =>
- if lt_dec (proj1_sig i) middle then part1 else part2.
-
- Definition mergeVectorZero {T} {n} (middle:nat) (part1 : Vector T n) (c:T) : (Vector T n) :=
- fun (i: {n':nat | n' < n}%nat) =>
- if lt_dec (proj1_sig i) middle then (part1 i) else c.
-
- Definition scaleUnitVector {T} (n:nat) (j : {n':nat | (n' < n)%nat}) (c:T) (zero:T) : Vector T n :=
- fun i => if (proj1_sig i) == (proj1_sig j) then c else zero%R.
-
- Lemma ConstSplitVectorSzero bound n (pf:bound < n) :
- (ConstSplitVector (S bound) n 1%R 0%R) =
- dfti_gen_plus (T:=DTVector n) (ConstSplitVector bound n 1%R 0%R) (UnitVector n (exist _ bound pf)).
- Proof.
- simpl.
- apply functional_extensionality.
- intros.
- unfold ConstSplitVector, UnitVector, equiv_dec, nat_eq_eqdec; simpl.
- destruct x as [x pff]; simpl.
- destruct (lt_dec x (S bound))
- ; destruct (lt_dec x bound)
- ; destruct (Nat.eq_dec x bound)
- ; try lia; try lra.
- Qed.
-
- Lemma mergeVectorSzero {n} (bound:nat) (pf:bound < n) (part1 : Vector R n):
- let ind := (exist _ bound pf) in
- (@mergeVectorZero (@float floatish_R) n (S bound) part1 0%R) =
- dfti_gen_plus (T:=DTVector n) (@mergeVectorZero (@float floatish_R) n bound part1 0%R)
- (scaleUnitVector n ind (part1 ind) 0%R).
- Proof.
- simpl.
- apply functional_extensionality.
- intros.
- unfold mergeVectorZero, scaleUnitVector, equiv_dec, nat_eq_eqdec; simpl.
- destruct x as [x pff]; simpl.
- destruct (lt_dec x (S bound))
- ; destruct (lt_dec x bound)
- ; destruct (Nat.eq_dec x bound)
- ; try lia; try lra.
- subst; simpl.
- ring_simplify.
- erewrite index_pf_irrel; eauto.
- Qed.
-
- Lemma mergeVectorSzero_mat {n m} (bound:nat) pf (part1 : Matrix float n m) :
- let ind := (exist _ bound pf) in
- (@mergeVectorZero (Vector float m) n (S bound) part1 (ConstVector m 0%R)) =
- dfti_gen_plus (T:=DTMatrix n m) (@mergeVectorZero (Vector float m) n bound part1 (ConstVector m 0%R))
- (scaleUnitVector (T:=Vector float m) n ind (part1 ind) (ConstVector m 0%R)).
- Proof.
- simpl.
- do 2 (apply functional_extensionality; intros).
- unfold mergeVectorZero, scaleUnitVector, ConstVector, equiv_dec, nat_eq_eqdec; simpl.
- destruct x as [x pff]; simpl.
- destruct (lt_dec x (S bound))
- ; destruct (lt_dec x bound)
- ; destruct (Nat.eq_dec x bound)
- ; simpl
- ; try lia; try lra.
- subst; simpl.
- rewrite Rplus_0_l.
- erewrite index_pf_irrel; eauto.
- Qed.
-
- Lemma vsum_alt_eq {m:nat} (v:Vector R m) : vsum v = vector_fold_right Fplus 0%R v.
- Proof.
- apply vector_fold_right1_as_vector_fold_right.
- unfold Datatypes.id; simpl; intros; lra.
- Qed.
-
- Lemma vsum_cons {m:nat} x (v:Vector R m) :
- vsum (vcons x v) = (x + vsum v)%R.
- Proof.
- repeat rewrite vsum_alt_eq.
- apply vector_fold_right_vcons.
- Qed.
-
- Lemma constSplitVectorZero {n} :
- ConstSplitVector 0 n 1%R 0%R = ConstVector n 0%R.
- Proof.
- unfold ConstSplitVector, ConstVector; simpl.
- now apply functional_extensionality.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_partvec_bounded {n} bound (pf:(bound<=n)%nat)
- (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (grad d:Vector float n):
-
- fully_closed_over df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall i : {n' : nat | n' < n},
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (scaleUnitVector n i (grad i) 0%R)) = Some (d i)) ->
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (mergeVectorZero bound grad 0%R)) = Some (vsum (vfirstn d bound pf)).
- Proof.
- intros closed lo fa.
- induction bound.
- - replace (mergeVectorZero 0 grad 0%R) with (scalarMult (DTVector n) 0%R (mergeVectorZero 0 grad 0%R)).
- + erewrite scalarMult_backprop_grad_scalar; try eassumption.
- * simpl.
- unfold df_eval_backprop_delta.
- simpler2.
- unfold lift.
- simpl_closed_backprop.
- f_equal.
- vm_compute; lra.
- * apply backprop_deriv_fully_closed_not_none; trivial.
- * apply backprop_deriv_fully_closed_not_none; trivial.
- + unfold scalarMult, mergeVectorZero.
- apply functional_extensionality; intros; simpl.
- lra.
- - assert (pf2:bound < n) by lia.
- assert (pf3:bound <= n) by lia.
- rewrite (mergeVectorSzero _ pf2 _ ).
- erewrite backprop_grad_sum; try eassumption.
- specialize (IHbound pf3).
- rewrite IHbound.
- rewrite fa.
- simpl.
- f_equal.
- destruct n; [lia | ].
- generalize (vector_Sn_split (vfirstn d (S bound) pf)); intros eqq1.
- apply vec_eq_eq in eqq1.
- simpl in *.
- rewrite eqq1.
- erewrite vfirstn_vdrop_last.
- erewrite vlast_vfirstn.
- rewrite vsum_cons.
- rewrite Rplus_comm.
- f_equal.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_partvec {n}
- (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (grad d:Vector float n):
-
- fully_closed_over df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall i : {n' : nat | n' < n},
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (scaleUnitVector n i (grad i) 0%R)) = Some (d i)) ->
- df_eval_backprop_delta σ df (s, DTfloat) grad_env grad =
- Some (vsum d).
- Proof.
- intros.
- replace (grad) with (mergeVectorZero n grad 0%R).
- - erewrite df_eval_backprop_delta_by_unit_partvec_bounded; try eassumption.
- now rewrite vfirstn_eq.
- - apply functional_extensionality; unfold mergeVectorZero; intros [x pff]; simpl.
- destruct (lt_dec x n); trivial; lia.
- Unshelve.
- lia.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_parts {n}
- (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (d:Vector float n):
-
- fully_closed_over df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall i : {n' : nat | n' < n},
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
-
- (UnitVector n i)) = Some (d i)) ->
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (ConstVector n 1%R)) = Some (vsum d).
- Proof.
- intros.
- replace (ConstVector n 1%R) with (mergeVectorZero n (ConstVector n 1%R) 0%R).
- - erewrite df_eval_backprop_delta_by_unit_partvec_bounded; try eassumption.
- now rewrite vfirstn_eq.
- - apply functional_extensionality; unfold mergeVectorZero, ConstVector; intros [x pff]; simpl.
- destruct (lt_dec x n); trivial; lia.
- Unshelve.
- lia.
- Qed.
-
- Lemma scalarMult_mult {T} a b grad : scalarMult T a (scalarMult T b grad) = scalarMult T (a*b)%R grad.
- Proof.
- destruct T; simpl.
- - lra.
- - apply vec_eq_eq; intros ?; lra.
- - do 2 (apply vec_eq_eq; intros ?); lra.
- Qed.
-
- Corollary scalarMult_backprop_grad0 {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (s: SubVar) (grad_env :df_env) (grad : definition_function_types_interp T) :
- let v := (s, DTfloat) in
- vartlookup grad_env v <> None ->
- df_eval_backprop_deriv σ df grad_env (scalarMult T 0%R grad) <> None ->
- df_eval_backprop_delta σ df v grad_env (scalarMult T 0%R grad) = Some 0%R.
- Proof.
- simpl; intros.
- (* This allows us to drop the (unneeded) assumption
- df_eval_backprop_deriv σ df grad_env1 grad <> None
- *)
- replace (scalarMult T 0%R grad) with (scalarMult T 0%R (scalarMult T 0%R grad)).
- - erewrite scalarMult_backprop_grad_scalar; try eassumption.
- + simpl.
- unfold df_eval_backprop_delta.
- simpler2; simpl.
- destruct (df_eval_backprop_deriv σ df grad_env (scalarMult T 0%R grad)); simpl; [| congruence].
- f_equal; lra.
- + now rewrite scalarMult_mult, Rmult_0_l.
- - now rewrite scalarMult_mult, Rmult_0_l.
- Qed.
-
- Lemma scaleUnitVec_vec_plus_distr m n i (x y:Vector float m) c :
- vec_eq c (dfti_gen_plus (T:=DTVector m) c c) ->
- (scaleUnitVector n i (dfti_gen_plus (T:=DTVector m) x y) c) =
- (dfti_gen_plus (T:=DTMatrix n m) (scaleUnitVector n i x c) (scaleUnitVector n i y c)).
- Proof.
- intros fa.
- do 2 (apply functional_extensionality; intro).
- unfold scaleUnitVector; simpl.
- match_destr
- ; try now specialize (fa x1); simpl in fa.
- Qed.
-
-
- Lemma df_eval_backprop_delta_by_unit_partmat_outer_bounded {n m} bound (pf:(bound<=n)%nat)
- (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env (grad d:Matrix float n m):
-
-
- fully_closed_over df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall (i : {n' : nat | n' < n}),
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (scaleUnitVector n i (grad i) (ConstVector m 0%R))) = Some (vsum (d i))) ->
-
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (mergeVectorZero bound grad (ConstVector m 0%R))) = Some (vsum (vfirstn (vmap vsum d) bound pf)).
- Proof.
- intros closed lo fa.
- induction bound.
- - rewrite vfirstn0, vsum_nil.
-
- replace (mergeVectorZero 0 grad (ConstVector m 0%R)) with (scalarMult (DTMatrix n m) 0%R (mergeVectorZero 0 grad (ConstVector m 0%R))).
- + apply scalarMult_backprop_grad0; simpl in *; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- + unfold scalarMult, ConstVector, mergeVectorZero.
- do 2 (apply functional_extensionality; intros); simpl.
- lra.
- - assert (pf2:bound < n) by lia.
- assert (pf3:bound <= n) by lia.
- simpl.
- generalize (mergeVectorSzero_mat (m:=m) _ pf2 grad ); intros HH; unfold Vector in HH.
- simpl float in *.
- rewrite HH; clear HH.
- erewrite backprop_grad_sum; try eassumption.
- specialize (IHbound pf3).
- rewrite IHbound.
- rewrite fa.
- simpl.
- f_equal.
- generalize (vector_Sn_split (vfirstn (vmap vsum d) (S bound) pf)); intros eqq1.
- apply vec_eq_eq in eqq1.
- simpl in *.
- rewrite eqq1.
- erewrite vfirstn_vdrop_last.
- erewrite vlast_vfirstn.
- rewrite vsum_cons.
- rewrite Rplus_comm.
- f_equal.
- rewrite vmap_nth.
- eauto.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_partmat_inner_bounded {n m} bound (pf:(bound<=m)%nat)
- (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env (grad d:Matrix float n m) i:
-
-
- fully_closed_over df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall j : {n' : nat | n' < m},
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- ((scaleUnitVector n i
- (scaleUnitVector m j (grad i j) 0%R) (ConstVector m 0%R)))) = Some (d i j)) ->
-
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (scaleUnitVector n i (mergeVectorZero bound (grad i) 0%R) (ConstVector m 0%R))) = Some (vsum (vfirstn (d i) bound pf)).
- Proof.
- intros closed lo fa.
- induction bound.
- - rewrite vfirstn0, vsum_nil.
- replace (scaleUnitVector n i (mergeVectorZero 0 (grad i) 0%R) (ConstVector m 0%R)) with (scalarMult (DTMatrix n m) 0%R (mergeVectorZero 0 grad (ConstVector m 0%R))).
- + apply scalarMult_backprop_grad0; simpl in *; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- + unfold scalarMult, ConstVector, mergeVectorZero, scaleUnitVector.
- do 2 (apply functional_extensionality; intros); simpl.
- match_destr; lra.
- - assert (pf2:bound < m) by lia.
- assert (pf3:bound <= m) by lia.
- simpl.
- generalize (mergeVectorSzero _ pf2 (grad i) ); intros HH; unfold Vector in HH.
- simpl float in *.
- rewrite HH; clear HH.
- rewrite (scaleUnitVec_vec_plus_distr m n i) by (intros ?; unfold ConstVector; simpl; lra).
- erewrite backprop_grad_sum; try eassumption.
- specialize (IHbound pf3).
- simpl in *.
- rewrite IHbound.
- simpl.
- rewrite fa.
- simpl.
- f_equal.
- generalize (vector_Sn_split (vfirstn (d i) (S bound) pf)); intros eqq1.
- apply vec_eq_eq in eqq1.
- simpl in *.
- rewrite eqq1.
- erewrite vfirstn_vdrop_last.
- erewrite vlast_vfirstn.
- rewrite vsum_cons.
- rewrite Rplus_comm.
- f_equal.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_partmat {n m}
- (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env
- (grad d:Matrix float n m):
- fully_closed_over
- df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve)
- σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
- (forall (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) ,
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (scalarMult (DTMatrix n m) (grad i j)
- (UnitMatrix n m i j))) =
- Some (d i j)) ->
- df_eval_backprop_delta σ df (s, DTfloat) grad_env grad =
- Some (msum d).
- Proof.
- intros closed lo fa.
- replace (grad) with (mergeVectorZero n grad (ConstVector m 0%R)).
- - erewrite df_eval_backprop_delta_by_unit_partmat_outer_bounded; try eassumption.
- + now rewrite vfirstn_eq.
- + intros i.
- specialize (fa i).
- simpl in fa.
- replace (grad i) with (mergeVectorZero m (grad i) 0%R).
- 2: {
- unfold mergeVectorZero; simpl; apply functional_extensionality; intros; destruct x.
- simpl.
- match_destr; lia.
- }
-
- erewrite (df_eval_backprop_delta_by_unit_partmat_inner_bounded m (le_refl m) σ df s grad_env)
- ; try eassumption.
- * now rewrite vfirstn_eq.
- * intros j.
- specialize (fa j); simpl in *.
- rewrite <- fa.
- f_equal.
- do 2 (apply functional_extensionality; intros).
- unfold scaleUnitVector, ConstVector, UnitMatrix; simpl.
- repeat match_destr; lra.
- - apply functional_extensionality; unfold mergeVectorZero; intros [x pff]; simpl.
- destruct (lt_dec x n); trivial; lia.
- Unshelve.
- lia.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_parts_mat {n m}
- (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env
- (d:Matrix float n m):
- fully_closed_over
- df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve)
- σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) ,
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
-
- (UnitMatrix n m i j)) = Some (d i j)) ->
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (ConstMatrix n m 1%R)) = Some (msum d).
- Proof.
- intros.
- apply df_eval_backprop_delta_by_unit_partmat; trivial.
- intros.
- rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env); trivial.
- unfold lift.
- rewrite H1.
- f_equal.
- unfold scalarMult, ConstMatrix; simpl; lra.
- apply backprop_deriv_fully_closed_not_none; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- Qed.
-
- Corollary scalarMult_backprop_list_env_iter_grad0 {T} (σ:df_env) (s: SubVar) (grad_env :df_env) (grad : definition_function_types_interp T) old n x l :
- let v := (s, DTfloat) in
- (forall j : {n' : nat | n' < n},
- fully_closed_over (x j)
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ)) ->
- vartlookup grad_env v = Some old ->
- lift (fun e => subvar v e old)
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv (Ann:=UnitAnn) σ (x i) env (scalarMult T 0%R grad)) (Some grad_env) l) = Some 0%R.
- Proof.
- simpl; intros.
- revert grad_env old H0.
- induction l; simpl; intros.
- - unfold subvar; simpl.
- rewrite H0; f_equal.
- lra.
- - unfold lift in *.
- case_eq (df_eval_backprop_deriv σ (x a) grad_env (scalarMult T 0%R grad))
- ; intros.
- + apply IHl.
- * generalize (scalarMult_backprop_grad0 σ (x a) s grad_env grad); intros HH.
- simpl in HH.
- cut_to HH; try congruence.
- unfold df_eval_backprop_delta in HH.
- rewrite H0 in HH.
- rewrite H1 in HH.
- simpl in HH.
- invcs HH.
- unfold subvar in H3; simpl in H3.
- simpler2.
- rewrite eqq in eqq0; invcs eqq0; subst.
- f_equal.
- lra.
- + eelim backprop_deriv_fully_closed_not_none; eauto.
- Qed.
-
- Lemma list_env_iter_gen_delta0 {n} (s: SubVar) (init_env : df_env)
- (f: {n' : nat | n' < n} -> df_env -> option df_env) (l : list {n' : nat | n' < n}):
- let v := (s,DTfloat) in
- vartlookup init_env v <> None ->
- (forall (i : {n' : nat | n' < n}) (env : df_env),
- (f i env) <> None /\
- (vartlookup env v <> None ->
- match f i env with
- | Some xenv => vartlookup xenv v <> None
- | _ => True
- end ) /\
- (In i l -> vartlookup env v <> None ->
- lift2 (fun e val => subvar v e val) (f i env)
- (vartlookup env v) = Some 0%R)) ->
- lift2 (fun e old => subvar v e old)
- (list_env_iter f (Some init_env) l)
- (vartlookup init_env v) = Some 0%R.
- Proof.
- simpl; intros.
- revert init_env H H0.
- induction l.
- - intros.
- simpl; f_equal.
- unfold subvar; simpl.
- match_option; [|tauto].
- f_equal.
- lra.
- - intros.
- assert (H0c := H0).
- specialize (H0 a init_env).
- destruct H0; destruct H1.
- simpl.
- case_eq (f a init_env); [intros|tauto].
- specialize (IHl d).
- replace (vartlookup init_env (s, DTfloat)) with (vartlookup d (s, DTfloat)).
- + apply IHl.
- * specialize (H1 H).
- now rewrite H3 in H1.
- * intros.
- specialize (H0c i env).
- destruct H0c; destruct H5.
- split; trivial.
- split; trivial.
- intros.
- cut_to H6; trivial.
- simpl; tauto.
- + case_eq (vartlookup init_env (s, DTfloat)); [intros|tauto].
- rewrite H3, H4 in H2; simpl in H2.
- cut_to H2; try tauto.
- unfold lift2 in H2; simpl in H2.
- inversion H2.
- unfold subvar in H6; simpl in H6.
- rewrite H3 in H1.
- specialize (H1 H).
- case_eq ( vartlookup d (s, DTfloat)); [intros|tauto].
- rewrite H5 in H6.
- f_equal; lra.
- congruence.
- Qed.
-
- Lemma list_env_iter_gen_delta {n} (s: SubVar) (init_env : df_env) (old : float)
- (f: {n' : nat | n' < n} -> df_env -> option df_env) (i0 : {n' : nat | n' < n}):
- let v := (s,DTfloat) in
- vartlookup init_env v = Some old ->
- (forall (i : {n' : nat | n' < n}) (env : df_env),
- (f i env) <> None /\
- (vartlookup env v <> None ->
- match f i env with
- | Some xenv => vartlookup xenv v <> None
- | _ => True
- end ) /\
- (forall (env2 : df_env),
- match vartlookup env v, vartlookup env2 v, f i env, f i env2 with
- | Some val1, Some val2, Some xenv, Some xenv2 =>
- subvar v xenv val1 = subvar v xenv2 val2
- | _, _, _, _ => True
- end) /\
- ( i <> i0 -> vartlookup env v <> None ->
- lift2 (fun e val => subvar v e val) (f i env)
- (vartlookup env v) = Some 0%R)) ->
- lift (fun e => subvar v e old) (f i0 init_env) =
- lift (fun e => subvar v e old)
- (list_env_iter f
- (Some init_env) (bounded_seq0 n)).
- Proof.
- simpl; intros.
- unfold bounded_seq0.
- destruct (bounded_seq_break_at 0 n i0) as [b [c [eqq1 [fa1 fa2]]]]; [lia |].
- rewrite eqq1.
- rewrite list_env_iter_app; simpl.
- match_option.
- -
- assert (eqq': subvar (s, DTfloat) d old = 0%R).
- + generalize (list_env_iter_gen_delta0 s init_env f b); intros.
- simpl in H1.
- cut_to H1; try congruence.
- * unfold lift2 in H1.
- rewrite H, eqq in H1.
- now inversion H1.
- * intros.
- specialize (H0 i env).
- destruct H0; destruct H2; destruct H3.
- split; trivial.
- split; trivial.
- intros.
- cut_to H4; trivial.
- rewrite Forall_forall in fa1.
- specialize (fa1 i H5).
- intro eq1; rewrite eq1 in fa1; lia.
- + generalize (vartlookup_list_env_iter s f b); intros vart.
- * specialize (vart init_env d eqq).
- assert (vartinit: vartlookup init_env (s, DTfloat) <> None) by congruence.
- specialize (vart vartinit).
- assert (f i0 d <> None) by apply H0.
- case_eq (f i0 d); [intros | tauto].
- generalize (list_env_iter_gen_delta0 s d0 f c); simpl; intros.
- cut_to H3; try congruence.
- -- unfold lift at 2.
- unfold lift2 in H3.
- match_option.
- ++ rewrite eqq0 in H3.
- match_option_in H3.
- rewrite (split_subvar d0 d1 old d2);trivial.
- inversion H3.
- rewrite H5.
- unfold lift.
- match_option.
- ** f_equal.
- case_eq (vartlookup d (s, DTfloat)); intros.
- --- rewrite (split_subvar d d0 old d4); trivial.
- rewrite eqq'.
- specialize (H0 i0 init_env).
- destruct H0; destruct H6; destruct H7.
- specialize (H7 d).
- rewrite H, H4, eqq3, H2 in H7.
- rewrite H7; lra.
- --- cut_to vart; [tauto|].
- intros.
- specialize (H0 i env).
- destruct H0; destruct H8.
- specialize (H8 H7).
- now rewrite H6 in H8.
- ** specialize (H0 i0 init_env).
- destruct H0; congruence.
- ++ generalize (list_env_iter_total_fun f d0 c); intros.
- cut_to H4; try congruence.
- apply H0.
- -- cut_to vart.
- ++ specialize (H0 i0 d).
- destruct H0; destruct H4.
- rewrite H2 in H4.
- apply H4; trivial.
- ++ intros.
- specialize (H0 i env).
- destruct H0;destruct H6.
- specialize (H6 H5).
- now rewrite H4 in H6.
- -- intros.
- specialize (H0 i env).
- split.
- apply H0.
- split.
- apply H0.
- intros.
- assert (i <> i0).
- ++ rewrite Forall_forall in fa2.
- specialize (fa2 i H4).
- intro eq1; rewrite eq1 in fa2; lia.
- ++ now apply H0.
- - generalize (list_env_iter_total_fun f init_env b); intros.
- cut_to H1; trivial.
- tauto.
- intros.
- apply H0.
- Qed.
-
- Lemma list_env_iter_vec_delta {n} (σ:df_env)
- (x:Vector (DefinedFunction UnitAnn DTfloat) n) (s: SubVar) grad_env
- (i0 : {n' : nat | n' < n}) (old : float) :
- let v := (s, DTfloat) in
- vartlookup grad_env v = Some old ->
- (forall (j: {n' : nat | n' < n}) ,
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over (x j) vl) ->
- lift (fun e => subvar v e old)
- (df_eval_backprop_deriv σ (x i0) grad_env 1%R) =
- lift (fun e => subvar v e old)
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i))
- (Some grad_env) (bounded_seq0 n)).
- Proof.
- simpl; intros.
- generalize (list_env_iter_gen_delta
- s grad_env old
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i))
- i0).
- simpl; intros.
- specialize (H1 H).
- rewrite <- H1.
- - unfold UnitVector; simpl.
- now destruct (equiv_dec (` i0) (` i0)); [|congruence].
- - clear H1.
- intros.
- split; [|split].
- + apply backprop_deriv_fully_closed_not_none; auto.
- + intros.
- match_option.
- now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial.
- + split.
- * intros.
- do 4 match_option.
- generalize (backprop_indep_env σ (x i) s env env2
- (UnitVector n i0 i)); simpl; intros HH.
- cut_to HH; trivial; try congruence.
- unfold df_eval_backprop_delta in HH.
- rewrite eqq,eqq0,eqq1,eqq2 in HH.
- unfold lift in HH.
- now inversion HH.
- * intros.
- unfold UnitMatrix; simpl.
- destruct (equiv_dec (` i) (` i0)).
- elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel.
- unfold lift2.
- generalize (scalarMult_backprop_grad0 σ (x i) s env 0%R); simpl; intros.
- unfold df_eval_backprop_delta in H3.
- specialize (H3 H2).
- replace (0 * 0)%R with 0%R in H3 by lra.
- match_option.
- match_option; [|tauto].
- -- rewrite eqq0 in H3.
- replace (UnitVector n i0 i) with Fzero in eqq; simpl in eqq.
- ++ rewrite eqq in H3; unfold lift in H3.
- apply H3; congruence.
- ++ unfold UnitVector; simpl.
- destruct (equiv_dec (` i) (` i0)); [|trivial].
- elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel.
- -- assert (df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i) <> None).
- now apply backprop_deriv_fully_closed_not_none.
- tauto.
- Qed.
-
-
- Lemma list_env_iter_backprop_indep_env_vec {m} (σ:df_env)
- (vecdf: Vector (DefinedFunction UnitAnn DTfloat) m)
- (s:SubVar) (env env2:df_env) (grad: Vector float m)
- (old1 old2: float) :
- let v := (s, DTfloat) in
- vartlookup env v = Some old1 ->
- vartlookup env2 v = Some old2 ->
- (let vl := map (fun ve => projT1 ve) σ in
- forall j : {m' : nat | m' < m},
- fully_closed_over (vecdf j) vl) ->
- lift (fun e => subvar (s, DTfloat) e old1)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env0 (grad j))
- (Some env) (bounded_seq0 m)) =
- lift (fun e => subvar (s, DTfloat) e old2)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env0 (grad j))
- (Some env2) (bounded_seq0 m)).
- Proof.
- intros.
- subst v.
- revert old1 old2 env env2 H H0.
- induction (bounded_seq0 m).
- - intros.
- simpl.
- unfold subvar; simpl.
- rewrite H, H0.
- f_equal; lra.
- - intros.
- simpl.
- case_eq (df_eval_backprop_deriv σ (vecdf a) env (grad a)); intros.
- case_eq (df_eval_backprop_deriv σ (vecdf a) env2 (grad a)); intros.
- case_eq (vartlookup d (s, DTfloat)); intros.
- case_eq (vartlookup d0 (s, DTfloat)); intros.
- + specialize (IHl d1 d2 d d0 H4 H5).
- unfold lift.
- do 2 match_option.
- * rewrite eqq, eqq0 in IHl.
- unfold lift in IHl.
- inversion IHl.
- f_equal.
- rewrite (split_subvar d d3 old1 d1); trivial.
- rewrite (split_subvar d0 d4 old2 d2); trivial.
- rewrite H7.
- generalize (backprop_indep_env
- σ (vecdf a) s
- env env2 (grad a)); simpl; intros.
- cut_to H6; trivial; try congruence.
- unfold df_eval_backprop_delta in H6.
- rewrite H, H0, H2, H3 in H6.
- unfold lift in H6.
- inversion H6.
- rewrite H9; lra.
- * generalize (list_env_iter_total_fun
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env0 (grad j))
- d0 l); intros.
- cut_to H6; [tauto|].
- intros.
- apply backprop_deriv_fully_closed_not_none; auto.
- * generalize (list_env_iter_total_fun
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env0 (grad j))
- d l); intros.
- cut_to H6; [tauto|].
- intros.
- apply backprop_deriv_fully_closed_not_none; auto.
- + apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H3
- ;trivial; congruence.
- + apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H2
- ;trivial; congruence.
- + assert (df_eval_backprop_deriv σ (vecdf a) env2 (grad a) <> None) by
- (apply backprop_deriv_fully_closed_not_none; auto); tauto.
- + assert (df_eval_backprop_deriv σ (vecdf a) env (grad a) <> None) by
- (apply backprop_deriv_fully_closed_not_none; auto); tauto.
- Qed.
-
- Lemma list_env_iter_mat_delta {n m} (σ:df_env)
- (x:Matrix (DefinedFunction UnitAnn DTfloat) n m) (s: SubVar) grad_env
- (i0 : {n' : nat | n' < n})
- (j0 : {m' : nat | m' < m}) (old : float) :
- let v := (s, DTfloat) in
- vartlookup grad_env v = Some old ->
- (forall (i: {n' : nat | n' < n}) (j: {m' : nat | m' < m}) ,
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over (x i j) vl) ->
- lift (fun e => subvar v e old)
- (df_eval_backprop_deriv σ (x i0 j0) grad_env 1%R) =
- lift (fun e => subvar v e old)
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv
- σ (x i j) env0
- (UnitMatrix n m i0 j0 i j)) (Some env)
- (bounded_seq0 m)) (Some grad_env) (bounded_seq0 n)).
- Proof.
- simpl; intros.
- generalize (list_env_iter_gen_delta
- s grad_env old
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv
- σ (x i j) env0
- (UnitMatrix n m i0 j0 i j)) (Some env)
- (bounded_seq0 m)) i0); simpl.
- intros.
- specialize (H1 H).
- rewrite <- H1.
- - clear H1.
- generalize (list_env_iter_gen_delta
- s grad_env old
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i0 j) env0 (UnitMatrix n m i0 j0 i0 j))
- j0); simpl.
- intros.
- specialize (H1 H).
- rewrite <- H1.
- + unfold UnitMatrix; simpl.
- destruct (equiv_dec (` i0) (` i0)); [|congruence].
- now destruct (equiv_dec (` j0) (` j0)); [|congruence].
- + clear H1.
- intros.
- split; [|split].
- * apply backprop_deriv_fully_closed_not_none; auto.
- * intros.
- match_option.
- now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial.
- * split.
- -- intros.
- do 4 match_option.
- generalize (backprop_indep_env σ (x i0 i) s env env2
- (UnitMatrix n m i0 j0 i0 i)); simpl; intros HH.
- cut_to HH; trivial; try congruence.
- unfold df_eval_backprop_delta in HH.
- rewrite eqq,eqq0,eqq1,eqq2 in HH.
- unfold lift in HH.
- now inversion HH.
- -- intros.
- unfold UnitMatrix; simpl.
- destruct (equiv_dec (` i0) (` i0)); [|congruence].
- destruct (equiv_dec (` i) (` j0)).
- elim H1; destruct i; destruct j0; simpl in *; red in e0; subst; apply index_pf_irrel.
- unfold lift2.
- generalize (scalarMult_backprop_grad0 σ (x i0 i) s env 0%R); simpl; intros.
- unfold lift2.
- replace (0 * 0)%R with 0%R in H3 by lra.
- unfold df_eval_backprop_delta in H3.
- match_option.
- match_option.
- rewrite eqq0, eqq in H3; unfold lift in H3.
- cut_to H3; congruence.
- rewrite eqq0 in H3.
- apply H3; trivial.
- tauto.
- congruence.
- assert (df_eval_backprop_deriv σ (x i0 i) env 0%R <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- - split.
- generalize (list_env_iter_total_fun
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j))
- env (bounded_seq0 m)); intros.
- apply H2.
- intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- split.
- + intros.
- match_option.
- apply (vartlookup_list_env_iter
- s (fun (j : {m' : nat | m' < m})
- (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j))
- (bounded_seq0 m) env d); trivial.
- intros.
- apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H3; trivial.
- + split.
- * intros.
- do 4 match_option.
- generalize (list_env_iter_backprop_indep_env_vec
- σ (x i) s env env2
- (UnitMatrix n m i0 j0 i)
- d d0); simpl; intros.
- specialize (H2 eqq eqq0).
- specialize (H2 (H0 i)).
- rewrite eqq1, eqq2 in H2.
- unfold lift in H2.
- now inversion H2.
- * intros.
- replace (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j))
- with
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0
- (scalarMult DTfloat 0%R 0%R)).
- case_eq (vartlookup env (s, DTfloat)); [intros | tauto].
- apply scalarMult_backprop_list_env_iter_grad0; trivial.
- apply functional_extensionality; intros.
- apply functional_extensionality; intros.
- f_equal.
- unfold scalarMult, UnitMatrix; simpl.
- destruct (equiv_dec (` i) (` i0)).
- red in e.
- elim H2; destruct i; destruct i0; simpl in *; subst.
- apply index_pf_irrel.
- lra.
- Qed.
-
- Lemma list_env_iter_matvec_delta {m n} (σ:df_env)
- (df2:DefinedFunction UnitAnn (DTVector m)) (s: SubVar) grad_env
- (i0 : {n' : nat | n' < n})
- (j0 : {m' : nat | m' < m}) (old : float) :
- vartlookup grad_env (s, DTfloat) = Some old ->
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df2 vl ->
- (lift (fun e => subvar (s, DTfloat) e old)
- (df_eval_backprop_deriv
- σ df2 grad_env
- (UnitVector m j0)) =
- (lift (fun e => subvar (s, DTfloat) e old)
- (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ df2 env
- ((UnitMatrix n m i0 j0) i))
- (Some grad_env)
- (bounded_seq0 n)))).
- Proof.
- simpl; intros.
- generalize (list_env_iter_gen_delta
- s grad_env old
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ df2 env
- ((UnitMatrix n m i0 j0) i))
- i0).
- simpl; intros.
- specialize (H1 H).
- rewrite <- H1.
- - unfold UnitMatrix, UnitVector; simpl.
- now destruct (equiv_dec (` i0) (` i0)); [|congruence].
- - clear H1.
- intros.
- split; [|split].
- + apply backprop_deriv_fully_closed_not_none; auto.
- + intros.
- match_option.
- now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial.
- + split.
- * intros.
- do 4 match_option.
- generalize (backprop_indep_env σ df2 s env env2
- (UnitMatrix n m i0 j0 i)); simpl; intros HH.
- cut_to HH; trivial; try congruence.
- unfold df_eval_backprop_delta in HH.
- rewrite eqq,eqq0,eqq1,eqq2 in HH.
- unfold lift in HH.
- now inversion HH.
- * intros.
- unfold UnitMatrix; simpl.
- destruct (equiv_dec (` i) (` i0)).
- elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel.
- unfold lift2.
- generalize (scalarMult_backprop_grad0 σ df2 s env (ConstVector m 0%R)); simpl; intros.
- unfold df_eval_backprop_delta in H3.
- specialize (H3 H2).
- match_option.
- match_option; [|tauto].
- -- rewrite eqq0 in H3.
- replace (fun i : {n' : nat | n' < m} => (0 * ConstVector m 0 i)%R) with
- (fun i : {n' : nat | n' < m} => 0%R) in H3.
- ++ rewrite eqq in H3; unfold lift in H3.
- apply H3; congruence.
- ++ apply functional_extensionality; intros.
- unfold ConstVector; simpl.
- lra.
- -- assert (df_eval_backprop_deriv σ df2 env (fun _ => 0%R) <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- Qed.
-
- Lemma vmap_eta {A B} {n} (f:A->B) (d:Vector A n) : vmap f d = vmap (fun x => f x) d.
- Proof.
- now apply vmap_ext.
- Qed.
-
- Lemma vsum_eta {n} (d:Vector float n) : vsum d = vsum (fun i => d i).
- Proof.
- now apply vsum_ext.
- Qed.
-
- Lemma msum_unitvector m n x d1 d0 :
- msum
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (UnitVector m x i * d1 i j * d0 j)%R) =
- vsum
- (fun j : {n' : nat | n' < n} =>
- (d1 x j * d0 j)%R).
- Proof.
- unfold msum.
-
- transitivity (
- vsum
- (fun i : {n' : nat | n' < m} =>
- vsum
- (fun (j : {m' : nat | m' < n}) => (UnitVector m x i * d1 i j * d0 j)%R))).
- { apply vsum_ext; intros ?.
- now rewrite vmap_nth.
- }
- transitivity (
- vsum
- (fun i : {n' : nat | n' < m} =>
- ((UnitVector m x i * vsum
- (fun (j : {m' : nat | m' < n}%nat) => d1 i j * d0 j))%R))).
- {
- apply vsum_ext; intros ?.
- rewrite vsum_mult.
- apply vsum_ext; intros ?.
- lra.
- }
- transitivity (
- vsum
- (fun i : {n' : nat | n' < m} =>
- (vsum (fun j : {m' : nat | (m' < n)%nat} => d1 i j * d0 j) * UnitVector m x i)%R)).
- {
- apply vsum_ext; intros ?; lra.
- }
- now rewrite vsum_unitvector.
- Qed.
-
- Lemma vsum_as_sum {n} (v:Vector float n) : vsum v = fold_right Rplus R0 (vector_to_list v).
- Proof.
- rewrite vsum_alt_eq.
- unfold vector_to_list, vector_fold_right.
- induction n.
- - rewrite vector_fold_right_dep_0; trivial.
- - repeat rewrite vector_fold_right_dep_Sn.
- now rewrite IHn.
- Qed.
-
- Lemma msum_as_sum {m n} (mat:Matrix float m n) : msum mat = fold_right Rplus R0 (matrix_to_list mat).
- Proof.
- unfold msum, matrix_to_list, matrix_to_list_list.
- transitivity
- (vsum (fun i => fold_right Rplus R0 (vector_to_list (mat i)))).
- { apply vsum_ext; intros [??].
- now rewrite vmap_nth, vsum_as_sum.
- }
- rewrite vsum_as_sum.
- rewrite fold_right_plus_concat.
- rewrite map_vector_to_list_vmap.
- f_equal.
- apply vector_to_list_ext.
- intros [??].
- now rewrite vmap_nth.
- Qed.
-
- Lemma transpose_perm_bounded {m n : nat} (mat : Matrix float m n) bound_m pf_m bound_n pf_n:
- Permutation
- (concat
- (vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) []
- (fun i : {n' : nat | n' < m} =>
- vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] (mat i) bound_n pf_n) bound_m pf_m))
- (concat
- (vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) []
- (fun i : {n' : nat | n' < n} =>
- vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] (transpose mat i) bound_m pf_m) bound_n pf_n)).
- Proof.
- Hint Constructors Permutation : fml.
- revert bound_n pf_n.
- induction bound_m; intros; simpl.
- - induction bound_n; simpl; trivial.
- - rewrite IHbound_m.
- clear IHbound_m.
- induction bound_n; simpl; trivial.
- rewrite <- IHbound_n.
- clear IHbound_n.
- apply Permutation_cons; trivial.
- unfold transpose; simpl.
- repeat rewrite <- app_ass.
- apply Permutation_app; trivial.
- apply Permutation_app_comm.
- Qed.
-
- Lemma transpose_perm {m n} (mat : Matrix float m n) :
- Permutation (matrix_to_list mat) (matrix_to_list (transpose mat)).
- Proof.
- apply transpose_perm_bounded.
- Qed.
-
- Lemma msum_transpose {m n} (mat : Matrix float m n) :
- msum mat = msum (transpose mat).
- Proof.
- repeat rewrite msum_as_sum.
- apply fold_right_perm; intros; try lra.
- apply transpose_perm.
- Qed.
-
- Ltac match_nested_case :=
- match goal with
- | [|- context[match match ?x with _ => _ end with _ => _ end]] =>
- let eqq := fresh "eqq" in
- case_eq x
- ; [intros ? eqq | intros eqq]
- end.
-
- Ltac match_nested_case_in H :=
- match H with
- | context[match match ?x with _ => _ end with _ => _ end] =>
- let eqq := fresh "eqq" in
- case_eq x
- ; [intros ? eqq | intros eqq]
- ; rewrite eqq in H
- end.
-
- Theorem df_eval_deriv_genvar_same (σ:df_env) (df:DefinedFunction UnitAnn DTfloat) (v:SubVar) :
- let vl := map (fun ve => projT1 ve) σ in
- is_scalar_function df ->
- fully_closed_over df vl ->
- let forward := df_eval_deriv_gen_top σ df (v, DTfloat) in
- lift transpose_lifted_type forward = df_eval σ (df_deriv df (v,DTfloat)).
- Proof.
- simpl.
- intros is_scalar.
- generalize is_scalar.
- pattern df.
- revert df is_scalar.
- DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl; trivial; intros
- ; try
- ( cut_to H; trivial
- ;rewrite <- H
- ;assert (df_eval σ e <> None) by (apply eval_fully_closed_not_none; trivial)
- ;unfold lift, df_eval_deriv_gen_top; simpl
- ;match_nested_case; [|tauto]
- ;now match_nested_case).
-
- - Case "Var"%string.
- match_nested_case.
- + red in e.
- inversion e; subst.
- now refl_simpler; simpl.
- + now intros.
- - Case "Plus"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_option_in H.
- + match_option_in H0.
- * rewrite <- H.
- now rewrite <- H0.
- * assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- - Case "Minus"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_option_in H.
- + match_option_in H0.
- * rewrite <- H.
- now rewrite <- H0.
- * assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- - Case "Times"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_nested_case; trivial.
- match_option_in H.
- + match_nested_case.
- * match_option_in H0.
- -- rewrite <- H.
- now rewrite <- H0.
- -- assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- * assert (df_eval σ r <> None)
- ; [apply eval_fully_closed_not_none; trivial | tauto].
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- - Case "Divide"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_nested_case; trivial.
- match_option_in H.
- + match_nested_case.
- * match_option_in H0.
- -- rewrite <- H.
- now rewrite <- H0.
- -- assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- * assert (df_eval σ r <> None)
- ; [apply eval_fully_closed_not_none; trivial | tauto].
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- + assert (df_eval σ l <> None)
- ; [apply eval_fully_closed_not_none; trivial | tauto].
- - Case "Sign"%string.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_nested_case; [trivial|].
- assert (df_eval_deriv_genvar σ e [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- - Case "PSign"%string.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_nested_case; [trivial|].
- assert (df_eval_deriv_genvar σ e [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- - Case "Max"%string.
- destruct is_scalar; destruct H1.
- cut_to H; trivial.
- cut_to H0; trivial.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- assert (df_eval σ l <> None) by (apply eval_fully_closed_not_none; trivial).
- assert (df_eval σ r <> None) by (apply eval_fully_closed_not_none; trivial).
- match_nested_case; [|tauto].
- match_nested_case; [|tauto].
- match_option_in H.
- match_option_in H0.
- + rewrite <- H.
- rewrite <- H0.
- unfold pos_sign; simpl.
- case_eq (Rle_dec d d0); intros; f_equal.
- destruct (Rge_dec (d0 - d) 0); lra.
- destruct (Rge_dec (d0 - d) 0); lra.
- + assert (df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- Qed.
-
- Lemma yay {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (hs:has_scalar_functions df) (s: SubVar) grad_env :
- let v := (s, DTfloat) in
- vartlookup grad_env v <> None ->
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- let forward := df_eval_deriv_gen_top σ df v in
- let backward := df_eval_backward_gen_top σ df v grad_env in
- lift transpose_lifted_type forward = backward.
- Proof.
- simpl.
- intros vin closed.
- revert grad_env vin closed.
- unfold df_eval_deriv_gen_top, df_eval_backward_gen_top.
- pattern T, df.
- revert T df hs.
- DefinedFunction_cases (apply DefinedFunction_ind_unit_has_scalar_functions) Case
- ; simpl; intros.
- - Case "Number"%string.
- unfold subvar; simpl.
- match_destr; [ | tauto].
- f_equal; lra.
- - Case "Constant"%string.
- unfold subvar; simpl.
- match_destr; simpl.
- + match_destr; [ | tauto].
- f_equal; lra.
- + match_destr; [ | tauto].
- erewrite vectoro_to_ovector_forall_some_b_strong
- ; simpl; trivial; intros.
- unfold ConstVector.
- f_equal; lra.
- + match_destr; [ | tauto].
- unfold matrixo_to_omatrix.
- repeat (erewrite vectoro_to_ovector_forall_some_b_strong
- ; simpl; trivial; intros).
- unfold ConstMatrix.
- f_equal; lra.
- - Case "DVector"%string.
- unfold lift, two_vector_env_iter_alt.
- case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto].
- match_option.
- + specialize (apply vectoro_to_ovector_forall_some_f eqq); intros.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- unfold snd in H.
- specialize (H i grad_env); simpl in H.
- rewrite vforall_forall in closed.
- assert (closedb := closed).
- specialize (closed i).
- specialize (H vin closed); simpl in H.
- rewrite H0 in H.
- specialize (H1 i); simpl in H1.
- rewrite H1 in H; simpl in H.
- unfold lift in H.
- match_option_in H.
- specialize (apply vectoro_to_ovector_forall_some_f eqq); intros.
- specialize (H2 i); simpl in H2.
- destruct i; simpl.
- unfold UnitVector; simpl.
- generalize (list_env_iter_total_fun
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env
- (if equiv_dec (` i) x0 then 1%R else 0%R))
- grad_env (bounded_seq0 n)); intros.
- cut_to H3.
- match_option; [|tauto].
- * rewrite H.
- generalize (list_env_iter_vec_delta σ x s grad_env
- (exist (fun n' : nat => n' < n) x0 l) d).
- intros.
- simpl in H4.
- specialize (H4 H0 closedb).
- rewrite eqq0 in H4.
- unfold UnitVector in H4; simpl in H4.
- rewrite eqq1 in H4.
- unfold lift in H4.
- symmetry; trivial.
- * intros.
- apply backprop_deriv_fully_closed_not_none.
- apply closedb.
- + specialize (vectoro_to_ovector_exists_None eqq); intros.
- destruct H1.
- generalize (eval_deriv_genvar_fully_closed_not_none σ (x x0) [mk_env_entry (s, DTfloat) 1%R]); intros.
- rewrite vforall_forall in closed.
- now specialize (H1 (closed x0)).
- - Case "DMatrix"%string.
- unfold lift, matrixo_to_omatrix.
- case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto].
- match_option.
- + specialize (apply vectoro_to_ovector_forall_some_f eqq); intros.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H1 i); simpl in H1.
- specialize (apply vectoro_to_ovector_forall_some_f H1); intros.
- specialize (H2 i0); simpl in H2.
- unfold two_matrix_env_iter_alt.
- specialize (H i i0 grad_env); simpl in H.
- rewrite vforall_forall in closed.
- assert (closedb := closed).
- specialize (closed i).
- rewrite vforall_forall in closed.
- specialize (closed i0).
- specialize (H vin closed); simpl in H.
- rewrite H2 in H; simpl in H.
- rewrite H0 in H; simpl in H.
- unfold lift in H.
- match_option_in H.
- rewrite H.
- destruct i.
- destruct i0.
- unfold lift; simpl.
- generalize (list_env_iter_total_fun
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv
- σ (x i j) env0
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x0 l)
- (exist (fun n' : nat => n' < m) x1 l0) i j))
- (Some env)
- (bounded_seq0 m))
- grad_env (bounded_seq0 n)); intros.
- cut_to H3.
- match_option; [|tauto].
- * generalize (list_env_iter_mat_delta σ x s grad_env
- (exist (fun n' : nat => n' < n) x0 l)
- (exist (fun n' : nat => n' < m) x1 l0) d).
- intros.
- simpl in H4.
- specialize (H4 H0).
- cut_to H4.
- rewrite eqq0 in H4.
- rewrite eqq1 in H4.
- unfold lift in H4.
- symmetry; trivial.
- intros.
- specialize (closedb i).
- rewrite vforall_forall in closedb.
- apply closedb.
- * intros.
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (closedb a).
- rewrite vforall_forall in closedb.
- apply closedb.
- + specialize (vectoro_to_ovector_exists_None eqq); intros.
- destruct H1.
- specialize (vectoro_to_ovector_exists_None e); intros.
- destruct H1.
- symmetry.
- generalize (eval_deriv_genvar_fully_closed_not_none σ (x x0 x1) [mk_env_entry (s, DTfloat) 1%R]); intros.
- rewrite vforall_forall in closed.
- specialize (closed x0).
- rewrite vforall_forall in closed.
- now specialize (H1 (closed x1)).
- - Case "Var"%string.
- unfold equiv_dec, vart_eqdec; simpl.
- destruct (vart_dec v (s, DTfloat)).
- + destruct v.
- inversion e.
- subst; simpl.
- refl_simpler; simpl.
- case_eq (vartlookup grad_env (s, DTfloat)); [intros |tauto].
- simpl; f_equal.
- symmetry.
- now apply subvar_addvar_scalar_eq.
- + case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto].
- destruct v.
- unfold snd.
- destruct d0.
- * simpl.
- unfold lift.
- destruct (vartlookup grad_env (s0, DTfloat)).
- -- f_equal; simpl; symmetry.
- now apply subvar_addvar_scalar_neq.
- -- f_equal; symmetry.
- unfold subvar; simpl.
- rewrite H; lra.
- * simpl; symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros; unfold lift.
- destruct (vartlookup grad_env (s0, DTVector n0)).
- -- f_equal; simpl.
- unfold ConstVector.
- now apply subvar_addvar_scalar_neq.
- -- f_equal; unfold ConstVector.
- unfold subvar; simpl.
- rewrite H; lra.
- * simpl; symmetry.
- unfold matrixo_to_omatrix.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- unfold lift.
- destruct (vartlookup grad_env (s0, DTMatrix m n0)).
- -- f_equal; unfold ConstMatrix.
- now apply subvar_addvar_scalar_neq.
- -- f_equal; unfold ConstMatrix.
- unfold subvar; simpl.
- rewrite H; lra.
- - Case "Plus"%string.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H.
- case_eq (df_eval_backprop_deriv σ l grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- destruct closed.
- specialize (H H1).
- invcs H.
- { specialize (H0 d1).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))| apply H2].
- match_option
- ; rewrite eqq2 in H0
- ; simpl in *.
- - match_option_in H0.
- unfold lift in H0.
- match_option_in H0.
- invcs H0.
- simpl.
- f_equal.
- unfold subvar; simpl.
- rewrite eqq3.
- match_option; lra.
- - match_option_in H0; simpl.
- + unfold lift in *.
- match_option_in H0.
- + elim (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat) vin eqq3).
- }
- + destruct closed.
- specialize (H H1).
- congruence.
- + destruct closed.
- specialize (H H1).
- congruence.
- + destruct closed.
- specialize (H H1).
- case_eq (df_eval_backprop_deriv σ l grad_env 1%R)
- ; [intros ? eqq2 | intros eqq2].
- rewrite eqq2 in H.
- simpl in *.
- congruence.
- rewrite eqq2 in H.
- now apply H.
- - Case "Minus"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H.
- case_eq (df_eval_backprop_deriv σ l grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- invcs H.
- { specialize (H0 d1).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- | apply H2].
- match_option
- ; rewrite eqq2 in H0
- ; simpl in *.
- - match_option_in H0.
- unfold lift in H0.
- match_option_in H0.
- invcs H0.
- simpl.
- f_equal.
- unfold subvar; simpl.
- rewrite eqq3.
- match_option.
- + case_eq (df_eval_backprop_deriv σ r d1 (- (1))%R); intros.
- * unfold lift.
- f_equal.
- match_option.
- -- generalize (scalarMult_backprop_grad_scalar σ r s d1 d1 1%R (-1)%R)
- ; intros; simpl in H0.
- cut_to H0.
- ++ unfold df_eval_backprop_delta in H0.
- rewrite eqq3 in H0; unfold lift in H0; simpl in H0.
- replace (-1 * 1)%R with (- (1))%R in H0 by lra.
- rewrite H, eqq4 in H0; inversion H0.
- unfold subvar in H0; simpl in H0.
- rewrite eqq6, eqq5 in H0.
- inversion H0.
- lra.
- ++ congruence.
- ++ congruence.
- ++ simpl; replace (-1 * 1)%R with (- (1))%R by lra; congruence.
- ++ congruence.
- -- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (s,DTfloat))
- ; intros.
- cut_to H0; congruence.
- * generalize (backprop_deriv_fully_closed_not_none σ r d1 (- (1))%R); intros.
- destruct H0; trivial.
- + generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s,DTfloat))
- ; intros.
- cut_to H; congruence.
- - match_option_in H0.
- + unfold lift.
- unfold lift in H0.
- match_option_in H0.
- match_option.
- specialize (df_eval_backprop_deriv_preserves_lookup_not_none eqq5).
- intros.
- specialize (H (s, DTfloat)).
- generalize (backprop_deriv_fully_closed_not_none σ r d1 (1%R)); intros.
- now destruct H3.
- + generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s,DTfloat))
- ; intros.
- cut_to H; congruence.
- }
- + unfold lift in H; simpl in H.
- match_option_in H.
- - Case "Times"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- case_eq (df_eval_backprop_deriv σ l grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ l s grad_env grad_env 1%R d1)
- ; intros; simpl in H7.
- unfold df_eval_backprop_delta in H7.
- rewrite eqq1 in H7; simpl in H7.
- specialize (H7 vin vin).
- rewrite eqq0 in H7.
- generalize (backprop_deriv_fully_closed_not_none σ l grad_env (d1 * 1)%R); intros.
- specialize (H8 H1); specialize (H7 H8).
- cut_to H7; try discriminate.
- invcs H.
- { specialize (H0 d3).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- | apply H2].
- unfold lift in H7; simpl in H7.
- match_option_in H7.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- specialize (H vin).
- generalize (backprop_deriv_fully_closed_not_none σ r d3 1%R); intros.
- specialize (H9 H2).
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H10 in H0.
- unfold lift; simpl.
- unfold lift in H0.
- match_option_in H0.
- - generalize (scalarMult_backprop_grad_scalar σ r s d2 d3 1%R d)
- ; intros.
- unfold df_eval_backprop_delta in H11; simpl in H11.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros.
- specialize (H12 vin).
- case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence].
- specialize (H11 H12 H).
- rewrite H10, H13 in H11.
- generalize (backprop_deriv_fully_closed_not_none σ r d2 (d * 1)%R); intros.
- specialize (H14 H2); specialize (H11 H14 H9).
- unfold lift in H11.
- match_option_in H11; [|congruence]; f_equal.
- rewrite (split_subvar d2 d7 d0 d6); trivial.
- match_option_in H0; invcs H0.
- rewrite eqq5 in H11; invcs H11; invcs H7.
- lra.
- - now match_option_in H0.
- }
- unfold lift in H.
- match_case_in H; intros.
- rewrite H7 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ l grad_env 1%R); intros.
- now specialize (H8 H1).
- - Case "Divide"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- case_eq (df_eval_backprop_deriv σ l grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ l s grad_env grad_env 1%R (1 / d1)%R)
- ; intros; simpl in H7.
- unfold df_eval_backprop_delta in H7.
- rewrite eqq1 in H7; simpl in H7.
- specialize (H7 vin vin).
- rewrite eqq0 in H7.
- generalize (backprop_deriv_fully_closed_not_none σ l grad_env (1 / d1 * 1)%R); intros.
- specialize (H8 H1) ; specialize (H7 H8).
- cut_to H7; try discriminate.
- replace (1 / d1 * 1)%R with (1/d1)%R in H7 by lra.
- invcs H.
- { specialize (H0 d3).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- | apply H2].
- unfold lift in H7; simpl in H7.
- match_option_in H7.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- specialize (H vin).
- generalize (backprop_deriv_fully_closed_not_none σ r d3 1%R); intros.
- specialize (H9 H2).
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H10 in H0.
- unfold lift; simpl.
- unfold lift in H0.
- match_option_in H0.
- - generalize (scalarMult_backprop_grad_scalar σ r s d2 d3 1%R (- d / (d1 * d1))%R)
- ; intros; simpl in H11.
- unfold df_eval_backprop_delta in H11; simpl in H11.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros.
- specialize (H12 vin).
- case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence].
- specialize (H11 H12 H).
- rewrite H10, H13 in H11.
- generalize (backprop_deriv_fully_closed_not_none σ r d2 (- d / (d1 * d1) * 1)%R); intros.
- specialize (H14 H2); specialize (H11 H14 H9).
- unfold lift in H11.
- match_option_in H11; [|congruence]; f_equal.
- rewrite (split_subvar d2 d7 d0 d6); trivial.
- match_option_in H0; invcs H0.
- rewrite eqq5 in H11; invcs H11; invcs H7.
- lra.
- - now match_option_in H0.
- }
- unfold lift in H.
- match_case_in H; intros.
- rewrite H7 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ l grad_env 1%R); intros.
- now specialize (H8 H1).
- - Case "Square"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (2 * d)%R)
- ; intros; simpl in H2.
- unfold df_eval_backprop_delta in H2.
- rewrite eqq1 in H2; simpl in H2.
- specialize (H2 vin vin).
- rewrite eqq0 in H2.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (2 * d * 1)%R); intros.
- specialize (H3 closed); specialize (H2 H3).
- cut_to H2; try discriminate.
- invcs H.
- unfold lift in H2; match_option_in H2.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; now rewrite H2.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H2 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H3 closed).
- - Case "Exp"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (exp d)%R)
- ; intros; simpl in H2.
- unfold df_eval_backprop_delta in H2.
- rewrite eqq1 in H2; simpl in H2.
- specialize (H2 vin vin).
- rewrite eqq0 in H2.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (exp d * 1)%R); intros.
- specialize (H3 closed); specialize (H2 H3).
- cut_to H2; try discriminate.
- invcs H.
- replace (1 * exp d)%R with (exp d * 1)%R by lra.
- unfold lift in H2; match_option_in H2.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; rewrite H2.
- f_equal; lra.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H2 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H3 closed).
- - Case "Log"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (1 / d)%R)
- ; intros; simpl in H2.
- unfold df_eval_backprop_delta in H2.
- rewrite eqq1 in H2; simpl in H2.
- specialize (H2 vin vin).
- rewrite eqq0 in H2.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (1 / d * 1)%R); intros.
- specialize (H3 closed); specialize (H2 H3).
- cut_to H2; try discriminate.
- invcs H.
- replace (1 / d)%R with (1 / d * 1)%R by lra.
- unfold lift in H2; match_option_in H2.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; rewrite H2.
- f_equal; lra.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H2 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H3 closed).
- - Case "Abs"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (sign d)%R)
- ; intros; simpl in H2.
- unfold df_eval_backprop_delta in H2.
- rewrite eqq1 in H2; simpl in H2.
- specialize (H2 vin vin).
- rewrite eqq0 in H2.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (sign d * 1)%R); intros.
- specialize (H3 closed); specialize (H2 H3).
- cut_to H2; try discriminate.
- invcs H.
- replace (1 * (@sign floatish_R d))%R with ((@sign floatish_R d) * 1)%R by lra.
- unfold lift in H2; match_option_in H2.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; rewrite H2.
- f_equal; lra.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H2 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H3 closed).
- - Case "Sign"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (0)%R)
- ; intros H1; simpl in H1.
- unfold df_eval_backprop_delta in H1.
- rewrite eqq1 in H1; simpl in H1.
- specialize (H1 vin vin).
- rewrite eqq0 in H1.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (0 * 1)%R); intros.
- specialize (H2 closed); specialize (H1 H2).
- cut_to H1; try discriminate.
- invcs H.
- replace (0)%R with (0 * 1)%R by lra.
- unfold lift in H1; match_option_in H1.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; rewrite H1.
- f_equal; lra.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H1 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H2 closed).
- - Case "PSign"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (0)%R)
- ; intros H1; simpl in H1.
- unfold df_eval_backprop_delta in H1.
- rewrite eqq1 in H1; simpl in H1.
- specialize (H1 vin vin).
- rewrite eqq0 in H1.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (0 * 1)%R); intros.
- specialize (H2 closed); specialize (H1 H2).
- cut_to H1; try discriminate.
- invcs H.
- replace (0)%R with (0 * 1)%R by lra.
- unfold lift in H1; match_option_in H1.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; rewrite H1.
- f_equal; lra.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H1 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H2 closed).
- - Case "Max"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- specialize (H0 grad_env vin H2).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros|tauto].
- rewrite eqq0 in H.
- rewrite eqq0 in H0.
- destruct (Rle_dec d d1).
- + apply H0.
- + apply H.
- - Case "VectorDot"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros | tauto].
- replace (fun rv : R => (rv * 1)%R) with id.
- rewrite vmap_id; rewrite vmap_id.
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto].
- symmetry in H.
- generalize (backprop_deriv_fully_closed_not_none σ l grad_env d0); intros.
- specialize (H7 H1).
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq1 in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H8.
- match_option.
- + match_option; [|tauto].
- generalize (backprop_deriv_fully_closed_not_none σ r d4 d); intros.
- specialize (H9 H2).
- unfold lift.
- match_option; [|tauto].
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros.
- specialize (H10 vin).
- specialize (H0 d4 H10 H2).
- rewrite eqq0 in H0.
- match_option_in H0.
- symmetry in H0.
- unfold lift in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros.
- simpl in H11.
- f_equal.
- unfold lift in H8.
- rewrite (split_subvar d4 d5 d1 d6); trivial.
- rewrite <- vsum_plus.
- generalize (df_eval_backprop_delta_by_unit_partvec σ l s grad_env d0
- (fun i => (d2 i * d0 i)%R)); intros.
- specialize (H12 H1 vin).
- generalize (df_eval_backprop_delta_by_unit_partvec σ r s d4 d
- (fun i => (d i * d3 i)%R)); intros.
- specialize (H13 H2 H10).
- cut_to H12.
- cut_to H13.
- * unfold df_eval_backprop_delta in H12.
- rewrite eqq1 in H12.
- unfold df_eval_backprop_delta in H13.
- rewrite eqq4 in H13.
- unfold lift in H12.
- unfold lift in H13.
- rewrite eqq2 in H12.
- rewrite eqq3 in H13.
- invcs H12; invcs H13.
- rewrite H14, H15; lra.
- * intros.
- replace (@scaleUnitVector (@float floatish_R) n i (d i) (IZR Z0)) with
- (scalarMult (DTVector n) (d i) (UnitVector n i)).
- rewrite scalarMult_backprop_grad_scalar with (grad_env1 := d4) (grad_env2:=d4);trivial.
- unfold df_eval_backprop_delta, lift.
- rewrite eqq4.
- specialize (H11 i).
- destruct i; simpl in H11.
- simpl_closed_backprop.
- f_equal.
- rewrite eqq5 in H11.
- invcs H11.
- rewrite H15.
- now unfold scalarMult; simpl.
- apply backprop_deriv_fully_closed_not_none; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- unfold scalarMult, UnitVector, scaleUnitVector; simpl.
- apply functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)); lra.
- * intros.
- replace (@scaleUnitVector (@float floatish_R) n i (d0 i) (IZR Z0)) with
- (scalarMult (DTVector n) (d0 i) (UnitVector n i)).
- rewrite scalarMult_backprop_grad_scalar with (grad_env1 := grad_env) (grad_env2:=grad_env);trivial.
- unfold df_eval_backprop_delta, lift.
- rewrite eqq1.
- specialize (H8 i).
- destruct i; simpl in H8.
- simpl_closed_backprop.
- f_equal.
- rewrite eqq5 in H8.
- invcs H8.
- rewrite H15.
- now unfold scalarMult; simpl; lra.
- apply backprop_deriv_fully_closed_not_none; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- unfold scalarMult, UnitVector, scaleUnitVector; simpl.
- apply functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)); lra.
- + generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros.
- now specialize (H9 H2).
- + generalize (eval_deriv_genvar_fully_closed_not_none σ l [mk_env_entry (s, DTfloat) 1%R]); intros.
- now specialize (H8 H1).
- + apply FunctionalExtensionality.functional_extensionality.
- intros; unfold id; lra.
- - Case "VectorSum"%string.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H closed).
- symmetry in H.
- generalize (vectoro_to_ovector_forall_some_f H); intros HH; clear H.
- replace (@lift (@df_env floatish_R) R
- (fun e : df_env => subvar (s, DTfloat) e d0)
- (df_eval_backprop_deriv σ v grad_env
- (ConstVector n 1%R)))
- with
- (df_eval_backprop_delta σ v (s, DTfloat) grad_env
- (ConstVector n 1%R)).
- rewrite (df_eval_backprop_delta_by_unit_parts _ _ _ _ d); trivial.
- * intros.
- specialize (HH i).
- unfold df_eval_backprop_delta.
- rewrite eqq0.
- replace (UnitVector n i) with
- (coerce
- (df_eval_backward_gen_top_obligation_2 UnitAnn (DTVector n) v n eq_refl i)
- (UnitVector n i)); trivial.
- destruct i.
- now simpl.
- * unfold df_eval_backprop_delta.
- now rewrite eqq0.
- + specialize (H closed).
- symmetry in H.
- specialize (vectoro_to_ovector_exists_None H); intros.
- destruct H0.
- unfold lift in e.
- match_option_in e.
- generalize (backprop_deriv_fully_closed_not_none
- σ v grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_2
- UnitAnn (DTVector n) v n eq_refl x)
- (UnitVector n x))); intros.
- specialize (H0 closed); tauto.
- - Case "MatrixSum"%string.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H closed).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros HH; simpl in HH.
- replace (@lift (@df_env floatish_R) R
- (fun e : df_env => subvar (s, DTfloat) e d0)
- (df_eval_backprop_deriv σ v grad_env
- (ConstMatrix m n 1%R)))
- with
- (df_eval_backprop_delta σ v (s, DTfloat) grad_env
- (ConstMatrix m n 1%R)).
- rewrite (df_eval_backprop_delta_by_unit_parts_mat _ _ _ _ d); trivial.
- * intros.
- specialize (HH i).
- specialize (apply vectoro_to_ovector_forall_some_f HH); intros.
- specialize (H0 j); simpl in H0.
- unfold df_eval_backprop_delta.
- rewrite eqq0.
- replace (UnitMatrix m n i j) with
- (coerce
- (df_eval_backward_gen_top_obligation_3
- UnitAnn (DTMatrix m n) v m n eq_refl i j)
- (UnitMatrix m n i j)); trivial.
- destruct i; destruct j.
- now simpl.
- * unfold df_eval_backprop_delta.
- now rewrite eqq0.
- + specialize (H closed).
- symmetry in H.
- specialize (vectoro_to_ovector_exists_None H); intros.
- destruct H0.
- unfold lift in e.
- specialize (vectoro_to_ovector_exists_None e); intros.
- destruct H0.
- match_option_in e0.
- generalize (backprop_deriv_fully_closed_not_none
- σ v grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_3
- UnitAnn (DTMatrix m n) v m n eq_refl x x0)
- (UnitMatrix m n x x0))); intros.
- specialize (H0 closed); tauto.
- - Case "VectorElem"%string.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H closed).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H.
- unfold lift.
- replace (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i) then 1%R else 0%R)
- with
- (UnitVector n i).
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env (UnitVector n i)); intros.
- specialize (H1 closed).
- specialize (H0 i).
- match_option; symmetry; [|tauto].
- destruct i; simpl in *.
- unfold lift in H0; f_equal.
- rewrite eqq1 in H0.
- now invcs H0.
- unfold UnitVector.
- apply FunctionalExtensionality.functional_extensionality; intros.
- trivial.
- + specialize (H closed).
- symmetry in H.
- specialize (vectoro_to_ovector_exists_None H); intros.
- destruct H0.
- unfold lift in e.
- match_option_in e.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_2
- UnitAnn (DTVector n) l n eq_refl x)
- (UnitVector n x))); intros.
- specialize (H0 closed); tauto.
- - Case "MatrixElem"%string.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H closed).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H.
- unfold lift.
- replace
- (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i) then
- if equiv_dec (` k2) (` j) then 1%R else 0%R else 0%R)
- with (UnitMatrix m n i j).
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env (UnitMatrix m n i j)); intros.
- specialize (H1 closed).
- specialize (H0 i).
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H2.
- specialize (H2 j).
- match_option; symmetry; [|tauto].
- destruct i; destruct j; simpl in *.
- unfold lift in H2; f_equal.
- rewrite eqq1 in H2.
- now invcs H2.
- unfold UnitMatrix.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- trivial.
- + specialize (H closed).
- symmetry in H.
- specialize (vectoro_to_ovector_exists_None H); intros.
- destruct H0.
- unfold lift in e.
- specialize (vectoro_to_ovector_exists_None e); intros.
- destruct H0.
- match_option_in e0.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_3
- UnitAnn (DTMatrix m n) l m n eq_refl x x0)
- (UnitMatrix m n x x0))); intros.
- specialize (H0 closed); tauto.
- - Case "MatrixVectorMult"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros | tauto].
- assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option; [|tauto].
- assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- simpl_closed_backprop.
- simpl_closed_backprop.
- f_equal.
- rewrite eqq1, eqq in H.
- unfold lift in H; symmetry in H.
- specialize (vectoro_to_ovector_forall_some_f H); intros.
- specialize (H0 env).
- simpler2.
- cut_to H0; try congruence.
- rewrite eqq0 in H0.
- rewrite (split_subvar env env0 d3 val); trivial.
- specialize (H9 i); simpl in H9.
- specialize (vectoro_to_ovector_forall_some_f H9); intros.
- replace (@vsum floatish_R n (fun j : {n' : nat | n' < n} => (d i j * d2 j + d1 i j * d0 j)%R))
- with ((vsum (fun j => (d i j * d2 j)%R)) + (vsum (fun j => (d1 i j * d0 j)%R)))%R
- ; [|rewrite vsum_plus; f_equal].
- unfold lift in H0; symmetry in H0.
- specialize (vectoro_to_ovector_forall_some_f H0); intros.
- simpl in H11; simpl in H10.
- destruct i; simpl in eqq2; simpl in eqq3.
- unfold matrix_vector_mult in eqq3; simpl in eqq3.
- unfold UnitVector in eqq2; simpl in eqq2.
- replace (fun i : {n' : nat | n' < n} =>
- (@vsum floatish_R m
- (fun j : {n' : nat | n' < m} =>
- (d j i * (@UnitVector floatish_R m (exist (fun n' : nat => (n' < m)%nat) x l0)
- j)%R)%R)))
- with (d (exist (fun n' : nat => (n' < m)%nat) x l0)) in eqq3.
- + generalize (df_eval_backprop_delta_by_unit_partvec
- σ r s env (d (exist (fun n' : nat => n' < m) x l0 ))
- (fun i => ((d (exist (fun n' : nat => (n' < m)%nat) x l0 ) i) * (d2 i))%R))
- ; intros.
- specialize (H12 H2).
- cut_to H12; try congruence.
- * unfold df_eval_backprop_delta in H12.
- rewrite eqq3, eqq4 in H12.
- unfold lift in H12.
- invcs H12.
- apply Rplus_eq_compat_l.
- generalize (df_eval_backprop_delta_by_unit_partmat
- σ l s grad_env
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- ((if equiv_dec (` i) x then 1 else 0) * d0 j)%R)
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- ((if equiv_dec (` i) x then 1 else 0) * (d1 i j) * (d0 j))%R))
- ; intros.
- specialize (H12 H1).
- cut_to H12; try congruence.
- -- unfold df_eval_backprop_delta in H12.
- rewrite eqq1 in H12.
- unfold lift in H12.
- rewrite eqq2 in H12.
- invcs H12.
- rewrite H15.
- replace
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- ((if equiv_dec (` i) x then 1 else 0) * d1 i j * d0 j)%R) with
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (UnitVector m (exist _ x l0) i * d1 i j * d0 j)%R).
- now rewrite msum_unitvector.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- now unfold UnitVector; simpl.
- -- intros.
- rewrite scalarMult_backprop_grad_scalar with (grad_env2:=grad_env)
- ;trivial; try congruence.
- unfold lift.
- unfold df_eval_backprop_delta.
- rewrite eqq1.
- unfold lift.
- specialize (vectoro_to_ovector_forall_some_f H); intros.
- specialize (H13 i); simpl in H13.
- specialize (vectoro_to_ovector_forall_some_f H13); intros.
- specialize (H15 j); intros; simpl in H15.
- destruct i; destruct j; simpl in H15.
- match_option_in H15.
- f_equal; simpl.
- destruct (equiv_dec x0 x).
- invcs H15.
- rewrite H17; lra.
- lra.
- apply backprop_deriv_fully_closed_not_none; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- * intros.
- replace (@scaleUnitVector (@float floatish_R) n i (d (@exist nat (fun n' : nat => lt n' m) x l0) i) (IZR Z0)) with
- (scalarMult (DTVector n) (d (exist (fun n' : nat => n' < m) x l0) i) (UnitVector n i)).
- rewrite scalarMult_backprop_grad_scalar with (grad_env1 := env) (grad_env2:=env);trivial; try congruence.
- unfold scalarMult; simpl.
- unfold lift; simpl.
- specialize (H11 i).
- destruct i; simpl in H11.
- match_option_in H11.
- unfold df_eval_backprop_delta.
- rewrite eqq4,eqq5.
- unfold lift; f_equal.
- now invcs H11.
- apply backprop_deriv_fully_closed_not_none; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- unfold scalarMult, scaleUnitVector.
- apply FunctionalExtensionality.functional_extensionality; intros.
- unfold UnitVector.
- destruct (equiv_dec (` x0) (` i)); simpl; lra.
- + apply FunctionalExtensionality.functional_extensionality; intros.
- now rewrite vsum_unitvector.
- - Case "MatrixVectorAdd"%string.
- destruct closed.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3.
- match_option; symmetry.
- * apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H3 i).
- specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H4 i0).
- unfold lift in H4; unfold lift.
- match_option_in H4.
- unfold lift in H0.
- rewrite eqq1 in H0.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat))
- ;intros.
- specialize (H5 vin).
- specialize (H0 d2 H5 H2).
- match_option_in H0.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6.
- specialize (H6 i).
- destruct i; destruct i0; simpl; simpl in eqq2; simpl in H6.
- rewrite eqq2.
- match_option_in H6.
- generalize (list_env_iter_total_fun
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ r env
- (transpose
- (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1)) i))
- d2 (bounded_seq0 n)); intros.
- cut_to H7.
- -- case_eq (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ r env
- (@transpose R m n
- (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1)) i))
- (Some d2)
- (bounded_seq0 n)); [intros |tauto].
- f_equal.
- rewrite (split_subvar d2 d5 d0 d3); trivial.
- invcs H4.
- invcs H6.
- generalize (list_env_iter_matvec_delta
- σ r s d2
- (exist (fun n' : nat => n' < n) x0 l1)
- (exist (fun n' : nat => n' < m) x l0) d3).
- intros.
- specialize (H4 eqq3 H2).
- rewrite eqq4 in H4.
- replace (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ r env
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x0 l1)
- (exist (fun n' : nat => n' < m) x l0) i)) with
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ r env
- (@transpose R m n
- (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1)) i)) in H4.
- ++ rewrite H8 in H4.
- unfold lift in H4.
- invcs H4; lra.
- ++ apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- f_equal.
- unfold transpose, UnitMatrix.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl.
- match_case; intros.
- match_case; intros.
- -- intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- * assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- - Case "MatrixMult"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros | tauto].
- assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option; [|tauto].
- assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- simpl_closed_backprop.
- simpl_closed_backprop.
- f_equal.
- rewrite eqq1, eqq in H.
- unfold lift in H; symmetry in H.
- specialize (vectoro_to_ovector_forall_some_f H); intros.
- specialize (H0 env).
- simpler2.
- cut_to H0; try congruence.
- rewrite eqq0 in H0.
- rewrite (split_subvar env env0 d3 val); trivial.
- specialize (H9 i); simpl in H9.
- specialize (vectoro_to_ovector_forall_some_f H9); intros.
- replace (@vsum floatish_R p (fun j => (d i j * d2 j i0 + d1 i j * d0 j i0)%R))
- with ((vsum (fun j => (d i j * d2 j i0)%R)) + (vsum (fun j => (d1 i j * d0 j i0)%R)))%R
- ; [|rewrite vsum_plus; f_equal].
- unfold lift in H0; symmetry in H0.
- specialize (vectoro_to_ovector_forall_some_f H0); intros.
- simpl in H10; simpl in H11.
- destruct i; destruct i0; simpl in eqq2; simpl in eqq3.
- unfold matrix_mult,UnitMatrix in eqq3; simpl in eqq3.
- unfold matrix_mult,UnitMatrix in eqq2; simpl in eqq2.
- generalize (df_eval_backprop_delta_by_unit_partmat
- σ l s grad_env
- (fun (i : {n' : nat | n' < m}) (k : {m' : nat | m' < p}) =>
- vsum
- (fun j : {n' : nat | n' < n} =>
- ((if equiv_dec (` i) x then
- if equiv_dec (` j) x0 then 1 else 0 else 0) * d0 k j)%R))
- (fun (i : {n' : nat | n' < m}) (k : {m' : nat | m' < p}) =>
- vsum
- (fun j : {n' : nat | n' < n} =>
- ((if equiv_dec (` i) x then
- if equiv_dec (` j) x0 then 1 else 0 else 0) *
- (d1 i k) * d0 k j)%R)))
- ; intros.
- specialize (H12 H1 vin).
- cut_to H12; try congruence.
- + unfold df_eval_backprop_delta in H12.
- rewrite eqq1 in H12.
- unfold lift in H12.
- rewrite eqq2 in H12.
- invcs H12.
- rewrite H14.
- assert (msum
- (fun (i : {n' : nat | (n' < m)%nat}) (k : {m' : nat | (m' < p)%nat}) =>
- vsum
- (fun j : {n' : nat | (n' < n)%nat} =>
- (if equiv_dec (` i) x then
- if equiv_dec (` j) x0 then 1 else 0 else 0)
- * d1 i k * d0 k j))%R =
- vsum
- (fun j : {n' : nat | (n' < p)%nat} =>
- d1 (exist (fun n' : nat => (n' < m)%nat) x l0) j *
- d0 j (exist (fun n' : nat => (n' < n)%nat) x0 l1))%R).
- * rewrite msum_transpose.
- unfold msum.
- f_equal.
- unfold transpose; simpl.
- replace (fun (i : {m' : nat | m' < p}) (j : {n' : nat | n' < m}) =>
- (@vsum floatish_R n
- (fun j0 : {n' : nat | n' < n} =>
- ((if equiv_dec (` j) x then
- if equiv_dec (` j0) x0 then 1 else 0
- else 0) * d1 j i * d0 i j0)%R))) with
- (fun (i : {m' : nat | m' < p}) (j : {n' : nat | n' < m}) =>
- if equiv_dec (` j) x then
- (d1 (exist _ x l0) i * d0 i (exist _ x0 l1))%R
- else 0%R).
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- replace (fun j : {n' : nat | n' < m} =>
- if equiv_dec (` j) x
- then
- (d1 (exist (fun n' : nat => (n' < m)%nat) x l0) x1 *
- d0 x1 (exist (fun m' : nat => (m' < n)%nat) x0 l1))%R
- else 0%R) with
- (fun j : {n' : nat | n' < m} =>
- ((d1 (exist (fun n' : nat => (n' < m)%nat) x l0) x1 *
- d0 x1 (exist (fun m' : nat => (m' < n)%nat) x0 l1))%R *
- (@UnitVector floatish_R m (exist _ x l0) j))%R).
- now rewrite vsum_unitvector.
- apply FunctionalExtensionality.functional_extensionality; intros.
- unfold UnitVector; simpl.
- destruct (equiv_dec (` x2) x); lra.
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct ( equiv_dec (` x2) x).
- ++ replace
- (fun j0 : {n' : nat | n' < n} =>
- ((if equiv_dec (` j0) x0 then 1 else 0) * d1 x2 x1 * d0 x1 j0)%R)
- with
- (fun j0 : {n' : nat | n' < n} =>
- ((d1 x2 x1 * d0 x1 j0) * (UnitVector n (exist _ x0 l1) j0))%R).
- ** rewrite vsum_unitvector.
- red in e.
- subst.
- destruct x2.
- simpl.
- erewrite index_pf_irrel; eauto.
- ** apply FunctionalExtensionality.functional_extensionality; intros.
- unfold UnitVector; simpl.
- destruct (equiv_dec (` x3) x0); lra.
- ++ rewrite <- vsum_mult.
- lra.
- * rewrite H12.
- apply Rplus_eq_compat_r.
- generalize (df_eval_backprop_delta_by_unit_partmat
- σ r s env
- (fun (i : {n' : nat | n' < p}) (k : {m' : nat | m' < n}) =>
- vsum
- (fun j : {n' : nat | n' < m} =>
- (d j i *
- (if equiv_dec (` j) x then
- if equiv_dec (` k) x0 then 1 else 0 else 0))%R))
- (fun (i : {n' : nat | n' < p}) (k : {m' : nat | m' < n}) =>
- vsum
- (fun j : {n' : nat | n' < m} =>
- (d j i * d2 i k *
- (if equiv_dec (` j) x then
- if equiv_dec (` k) x0 then 1 else 0 else 0))%R)))
- ; intros.
- specialize (H13 H2).
- cut_to H13; try congruence.
- -- unfold df_eval_backprop_delta in H13.
- rewrite eqq4 in H13.
- unfold lift in H13.
- rewrite eqq3 in H13.
- invcs H13.
- rewrite H16.
- unfold msum; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- replace
- (fun k : {m' : nat | m' < n} =>
- (@vsum floatish_R m
- (fun j : {n' : nat | n' < m} =>
- (d j x1 * d2 x1 k *
- (if equiv_dec (` j) x then
- if equiv_dec (` k) x0 then 1 else 0 else 0))%R))) with
- (fun k : {m' : nat | m' < n} =>
- (vsum
- (fun j =>
- (d j x1 * d2 x1 k *
- (UnitVector m (exist _ x l0) j))%R)
- * (UnitVector n (exist _ x0 l1) k))%R).
- ++ rewrite vsum_unitvector.
- rewrite vsum_unitvector.
- lra.
- ++ apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite Rmult_comm.
- rewrite vsum_mult.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- unfold UnitVector; simpl.
- destruct (equiv_dec (` x2) x0); destruct (equiv_dec (` x3) x); lra.
- -- intros.
- rewrite scalarMult_backprop_grad_scalar with (grad_env2:=env); trivial; try congruence.
- ++ unfold lift.
- unfold df_eval_backprop_delta.
- rewrite eqq4.
- unfold lift.
- specialize (H11 i); simpl in H11.
- specialize (vectoro_to_ovector_forall_some_f H11); intros.
- specialize (H15 j); intros; simpl in H15.
- destruct i; destruct j; simpl in H15.
- match_option_in H15.
- f_equal; simpl.
- invcs H15.
- rewrite H17.
- rewrite Rmult_comm.
- rewrite vsum_mult.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- ++ apply backprop_deriv_fully_closed_not_none; trivial.
- ++ apply backprop_deriv_fully_closed_not_none; trivial.
- + intros.
- rewrite scalarMult_backprop_grad_scalar with (grad_env2:=grad_env); trivial; try congruence.
- * unfold lift.
- unfold df_eval_backprop_delta.
- rewrite eqq1.
- unfold lift.
- specialize (vectoro_to_ovector_forall_some_f H i); intros; simpl in H11.
- specialize (vectoro_to_ovector_forall_some_f H13 j); intros; simpl in H14.
- destruct i; destruct j; simpl in H14.
- match_option_in H14.
- f_equal; simpl.
- invcs H14.
- rewrite H16.
- rewrite Rmult_comm.
- rewrite vsum_mult.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- * apply backprop_deriv_fully_closed_not_none; trivial.
- * apply backprop_deriv_fully_closed_not_none; trivial.
- - Case "VectorPlus"%string.
- destruct closed.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3.
- match_option; symmetry.
- * apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H3 i).
- unfold lift in H3; unfold lift.
- match_option_in H3.
- unfold lift in H0.
- rewrite eqq1 in H0.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat))
- ;intros.
- specialize (H4 vin).
- specialize (H0 d2 H4 H2).
- match_option_in H0.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H5.
- specialize (H5 i).
- destruct i; simpl in H1; simpl; simpl in eqq2; simpl in H5.
- rewrite eqq2.
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (UnitVector n (exist (fun n' : nat => n' < n) x l0)))
- ; intros; specialize (H6 H2).
- match_option; [|tauto]; f_equal.
- rewrite eqq4 in H5; inversion H5.
- rewrite (split_subvar d2 d4 d0 d3); trivial.
- inversion H3; lra.
- * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros.
- specialize (H4 H2).
- tauto.
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_exists_None H); intros.
- destruct H3.
- unfold lift in e.
- match_option_in e.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_2
- UnitAnn (DTVector n) l n eq_refl x)
- (UnitVector n x))); intros.
- specialize (H3 H1); tauto.
- - Case "VectorMinus"%string.
- destruct closed.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3.
- match_option; symmetry.
- * apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H3 i).
- unfold lift in H3; unfold lift.
- match_option_in H3.
- unfold lift in H0.
- rewrite eqq1 in H0.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat))
- ;intros.
- specialize (H4 vin).
- specialize (H0 d2 H4 H2).
- match_option_in H0.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H5.
- specialize (H5 i).
- destruct i; simpl in H0; simpl; simpl in eqq2; simpl in H5.
- rewrite eqq2.
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (UnitVector n (exist (fun n' : nat => n' < n) x l0)))
- ; intros; specialize (H6 H2).
- generalize (scalarMult_backprop_grad_scalar
- σ r s d2 d2
- (UnitVector n (exist (fun n' : nat => n' < n) x l0))
- (-1)%R ); intros.
- specialize (H7 H4 H4).
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (scalarMult (DTVector n) (-1)%R
- (UnitVector n (exist (fun n' : nat => n' < n) x l0))))
- ; intros; specialize (H8 H2).
- specialize (H7 H8 H6).
- unfold df_eval_backprop_delta in H7; simpl in H7.
- rewrite eqq3 in H7.
- unfold lift in H7; simpl in H7.
- replace (fun i : {n' : nat | n' < n} =>
- (- (@UnitVector floatish_R n (exist (fun n' : nat => (n' < n)%nat) x l0)) i)%R)
- with
- (fun i : {n' : nat | n' < n} =>
- (-1 * UnitVector n (exist (fun n' : nat => (n' < n)%nat) x l0) i)%R).
- match_option; [|tauto]; f_equal.
- rewrite eqq4 in H7.
- rewrite (split_subvar d2 d4 d0 d3); trivial.
- rewrite H5 in H7.
- inversion H7.
- rewrite H10.
- inversion H3; lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros.
- specialize (H4 H2).
- tauto.
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_exists_None H); intros.
- destruct H3.
- unfold lift in e.
- match_option_in e.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_2 UnitAnn (DTVector n) l n eq_refl x)
- (UnitVector n x))); intros.
- specialize (H3 H1); tauto.
- - Case "MatrixPlus"%string.
- destruct closed.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H1.
- match_option; symmetry.
- * apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H3 i).
- specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H4 i0).
- unfold lift in H4; unfold lift.
- match_option_in H4.
- unfold lift in H0.
- rewrite eqq1 in H0.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat))
- ;intros.
- specialize (H5 vin).
- specialize (H0 d2 H5 H2).
- match_option_in H0.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6.
- specialize (H6 i).
- specialize (apply vectoro_to_ovector_forall_some_f H6); intros; simpl in H7.
- specialize (H7 i0).
- destruct i; destruct i0; simpl; simpl in eqq2; simpl in H7.
- rewrite eqq2.
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0)
- (exist (fun n' : nat => n' < m) x0 l1)))
- ; intros; specialize (H8 H2).
- match_option; [|tauto]; f_equal.
- rewrite eqq4 in H7; inversion H7.
- rewrite (split_subvar d2 d4 d0 d3); trivial.
- inversion H4; lra.
- * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros.
- now specialize (H4 H2).
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_exists_None H); intros.
- destruct H3.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H3.
- unfold lift in e0.
- match_option_in e0.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_3 UnitAnn (DTMatrix n m) l n m eq_refl x
- x0) (UnitMatrix n m x x0))); intros.
- specialize (H3 H1); tauto.
- - Case "MatrixMinus"%string.
- destruct closed.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H1.
- match_option; symmetry.
- * apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H3 i).
- specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H4 i0).
- unfold lift in H4; unfold lift.
- match_option_in H4.
- unfold lift in H0.
- rewrite eqq1 in H0.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat))
- ;intros.
- specialize (H5 vin).
- specialize (H0 d2 H5 H2).
- match_option_in H0.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6.
- specialize (H6 i).
- specialize (apply vectoro_to_ovector_forall_some_f H6); intros; simpl in H7.
- specialize (H7 i0).
- destruct i; destruct i0; simpl; simpl in eqq2; simpl in H7.
- rewrite eqq2.
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0)
- (exist (fun n' : nat => n' < m) x0 l1)))
- ; intros; specialize (H8 H2).
-
- generalize (scalarMult_backprop_grad_scalar
- σ r s d2 d2
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0)
- (exist (fun n' : nat => n' < m) x0 l1))
- (-1)%R ); intros.
- specialize (H9 H5 H5).
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (scalarMult (DTMatrix n m) (-1)%R
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0)
- (exist (fun n' : nat => n' < m) x0 l1))))
- ; intros; specialize (H10 H2).
- specialize (H9 H10 H8).
- unfold df_eval_backprop_delta in H9; simpl in H9.
- rewrite eqq3 in H9.
- unfold lift in H9; simpl in H9.
- replace
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (-
- (@UnitMatrix floatish_R n m (exist (fun n' : nat => (n' < n)%nat) x l0)
- (exist (fun n' : nat => (n' < m)%nat) x0 l1)) i j)%R)
- with
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (-1 *
- UnitMatrix n m (exist (fun n' : nat => (n' < n)%nat) x l0)
- (exist (fun n' : nat => (n' < m)%nat) x0 l1) i j)%R).
- match_option; [|tauto]; f_equal.
- rewrite eqq4 in H9.
- rewrite (split_subvar d2 d4 d0 d3); trivial.
- rewrite H7 in H9.
- inversion H9.
- rewrite H12.
- inversion H4; lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros.
- now specialize (H4 H2).
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_exists_None H); intros.
- destruct H3.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H3.
- unfold lift in e0.
- match_option_in e0.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_3 UnitAnn (DTMatrix n m) l n m eq_refl x
- x0) (UnitMatrix n m x x0))); intros.
- specialize (H3 H1); tauto.
- - Case "VectorScalMult"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ x); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H2).
- case_eq (df_eval σ l); [intros|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- case_eq (df_eval_backprop_deriv σ x grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- match_option; unfold lift; symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- replace (@vsum
- floatish_R n
- (fun j : @sig nat (fun n' : nat => lt n' n) =>
- Rmult (d1 j)
- (@coerce (Vector R n) (Vector R n)
- (@df_eval_backward_gen_top_obligation_2
- floatish_R UnitAnn
- (DTVector n)
- (@VectorScalMult floatish_R UnitAnn n ann x l) n
- (@eq_refl definition_function_types (DTVector n)) i)
- (@UnitVector floatish_R n i) j)))
- with (d1 i).
- generalize (scalarMult_backprop_grad_scalar
- σ x s grad_env grad_env 1%R (d1 i)); intros; simpl in H7.
- unfold df_eval_backprop_delta in H7.
- rewrite eqq1 in H7; simpl in H7.
- specialize (H7 vin vin).
- rewrite eqq0 in H7.
- generalize (backprop_deriv_fully_closed_not_none
- σ x grad_env (d1 i * 1)%R ); intros.
- specialize (H8 H1); specialize (H7 H8).
- cut_to H7; try discriminate.
- invcs H.
- { specialize (H0 d3).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- | apply H2].
- unfold lift in H7; simpl in H7.
- match_option_in H7.
- replace (d1 i) with (d1 i * 1)%R by lra.
- rewrite eqq3.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros.
- specialize (H vin).
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H9 in H0.
- unfold lift; simpl.
- unfold lift in H0.
- rewrite eqq2 in H0.
- generalize (scalarMult_backprop_grad_scalar σ l s d2 d3 (UnitVector n i) d)
- ; intros; simpl in H10.
- unfold df_eval_backprop_delta in H10; simpl in H10.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat));intros.
- specialize (H11 vin).
- case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence].
- specialize (H10 H11 H).
- rewrite H9, H12 in H10.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros.
- specialize (H13 i); simpl in H13.
- generalize (backprop_deriv_fully_closed_not_none
- σ l d2
- (fun i0 : {n' : nat | n' < n} => (d * UnitVector n i i0)%R)); intros.
- specialize (H14 H2); specialize (H10 H14).
- generalize (backprop_deriv_fully_closed_not_none
- σ l d3 (UnitVector n i)); intros.
- specialize (H15 H2); specialize (H10 H15).
- unfold lift in H10.
- match_option_in H10; [|congruence]; f_equal.
- replace
- (fun j : @sig nat (fun n' : nat => lt n' n) =>
- Rmult d
- (@coerce (Vector R n) (Vector R n)
- (@df_eval_backward_gen_top_obligation_2
- floatish_R UnitAnn
- (DTVector n) (@VectorScalMult floatish_R UnitAnn n ann x l) n
- (@eq_refl definition_function_types (DTVector n)) i)
- (@UnitVector floatish_R n i) j))
- with
- (fun i0 : {n' : nat | n' < n} => (d * UnitVector n i i0)%R).
- rewrite eqq4.
- rewrite (split_subvar d2 d7 d0 d6); trivial; f_equal.
- match_option_in H10; inversion H10.
- match_option_in H13.
- invcs H7; invcs H13.
- rewrite H17.
- destruct i; simpl in eqq6.
- rewrite eqq6 in eqq5.
- invcs eqq5.
- lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- now destruct i; simpl.
- }
- + symmetry.
- destruct i; simpl.
- apply vsum_unitvector.
- + specialize (H0 d3).
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros.
- specialize (H7 vin).
- specialize (H0 H7 H2).
- rewrite eqq2 in H0.
- unfold lift in H0.
- symmetry in H0.
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H8 in H0; simpl in H0.
- specialize (apply vectoro_to_ovector_exists_None H0); intros.
- destruct H9.
- generalize (backprop_deriv_fully_closed_not_none
- σ l d3
- (coerce
- (df_eval_backward_gen_top_obligation_2
- UnitAnn (DTVector n) l n eq_refl x0)
- (UnitVector n x0))); intros.
- specialize (H9 H2).
- match_option_in e.
- tauto.
- + unfold lift in H.
- generalize (backprop_deriv_fully_closed_not_none
- σ x grad_env 1%R); intros.
- specialize (H7 H1).
- match_option_in H.
- tauto.
- - Case "MatrixScalMult"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ x); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H2).
- case_eq (df_eval σ l); [intros|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- case_eq (df_eval_backprop_deriv σ x grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- match_option; unfold lift; symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- replace
- (@msum floatish_R n m
- (fun (i1 : @sig nat (fun n' : nat => lt n' n))
- (j : @sig nat (fun m' : nat => lt m' m)) =>
- Rmult (d1 i1 j)
- (@coerce (Matrix R n m) (Matrix R n m)
- (@df_eval_backward_gen_top_obligation_3 floatish_R UnitAnn
- (DTMatrix n m) (@MatrixScalMult floatish_R UnitAnn n m ann x l) n m
- (@eq_refl definition_function_types (DTMatrix n m)) i i0)
- (@UnitMatrix floatish_R n m i i0) i1 j)))
- with (d1 i i0).
- generalize (scalarMult_backprop_grad_scalar
- σ x s grad_env grad_env 1%R (d1 i i0)); intros; simpl in H7.
- unfold df_eval_backprop_delta in H7.
-
- rewrite eqq1 in H7; simpl in H7.
- specialize (H7 vin vin).
- rewrite eqq0 in H7.
- generalize (backprop_deriv_fully_closed_not_none
- σ x grad_env (d1 i i0 * 1)%R ); intros.
- specialize (H8 H1); specialize (H7 H8).
- cut_to H7; try discriminate.
- invcs H.
- { specialize (H0 d3).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- | apply H2].
- unfold lift in H7; simpl in H7.
- match_option_in H7.
- replace (d1 i i0) with (d1 i i0 * 1)%R by lra.
- rewrite eqq3.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros.
- specialize (H vin).
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H9 in H0.
- unfold lift in H0.
- rewrite eqq2 in H0.
- generalize (scalarMult_backprop_grad_scalar σ l s d2 d3 (UnitMatrix n m i i0) d)
- ; intros; simpl in H10.
- unfold df_eval_backprop_delta in H10; simpl in H10.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat));intros.
- specialize (H11 vin).
- case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence].
- specialize (H10 H11 H).
- rewrite H9, H12 in H10.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros.
- specialize (H13 i); simpl in H13.
- specialize (apply vectoro_to_ovector_forall_some_f H13); intros.
- specialize (H14 i0); simpl in H14.
- generalize (backprop_deriv_fully_closed_not_none
- σ l d2
- (fun (i1 : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d * UnitMatrix n m i i0 i1 j)%R)); intros.
- specialize (H15 H2); specialize (H10 H15).
- generalize (backprop_deriv_fully_closed_not_none
- σ l d3 (UnitMatrix n m i i0)); intros.
- specialize (H16 H2); specialize (H10 H16).
- unfold lift in H10.
- match_option_in H10; [|congruence]; f_equal.
- replace
- (fun (i1 : @sig nat (fun n' : nat => lt n' n))
- (j : @sig nat (fun m' : nat => lt m' m)) =>
- Rmult
- (@coerce (Matrix R n m) (Matrix R n m)
- (@df_eval_backward_gen_top_obligation_3
- floatish_R UnitAnn
- (DTMatrix n m) (@MatrixScalMult floatish_R UnitAnn n m ann x l) n m
- (@eq_refl definition_function_types (DTMatrix n m)) i i0)
- (@UnitMatrix floatish_R n m i i0) i1 j) d)
- with
- (fun i1 j => (d * UnitMatrix n m i i0 i1 j)%R).
- rewrite eqq4.
- rewrite (split_subvar d2 d7 d0 d6); trivial; f_equal.
- match_option_in H10; inversion H10.
- match_option_in H14.
- rewrite H18.
- invcs H7; invcs H14.
- destruct i; destruct i0; simpl in eqq6.
- rewrite eqq6 in eqq5.
- invcs eqq5.
- rewrite H17.
- lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct i; destruct i0; simpl.
- lra.
- }
- + symmetry.
- destruct i; destruct i0; simpl.
- apply msum_unitmatrix.
- + specialize (H0 d3).
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros.
- specialize (H7 vin).
- specialize (H0 H7 H2).
- rewrite eqq2 in H0.
- unfold lift in H0.
- symmetry in H0.
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H8 in H0; simpl in H0.
- specialize (apply vectoro_to_ovector_exists_None H0); intros.
- destruct H9.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H9.
- generalize (backprop_deriv_fully_closed_not_none
- σ l d3
- (coerce
- (df_eval_backward_gen_top_obligation_3
- UnitAnn (DTMatrix n m) l n m eq_refl x0 x1)
- (UnitMatrix n m x0 x1))); intros.
- specialize (H9 H2).
- match_option_in e0.
- tauto.
- + unfold lift in H.
- generalize (backprop_deriv_fully_closed_not_none
- σ x grad_env 1%R); intros.
- specialize (H7 H1).
- match_option_in H.
- tauto.
- - Case "VectorApply"%string.
- destruct closed.
- generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (0%R) :: nil) s0); intros.
- specialize (H4 H2).
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H3).
- match_case; [intros|tauto].
- specialize (H1 grad_env vin H3).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto].
- rewrite eqq1 in H1; simpl in H1.
- match_option
- ; rewrite eqq in H1; unfold lift in H1; symmetry in H1.
- + specialize (apply vectoro_to_ovector_forall_some_f H1); intros; simpl in H7.
- unfold lift; simpl.
- match_option.
- * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H8 i).
- specialize (H7 i).
- destruct i.
- simpl.
- match_nested_case.
- -- assert ( df_eval_backprop_deriv σ l grad_env v1 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- f_equal.
- simpl in H7.
- match_option_in H8.
- specialize (vectoro_to_ovector_forall_some_f eqq2); intros.
- assert (H10c := H10).
- specialize (H10 (exist _ x l0)).
- rewrite vmap_nth in H10; simpl in H10.
- match_option_in H10; simpl in H10.
- unfold UnitVector in H10; simpl in H10.
- destruct (equiv_dec x x); [|congruence].
- invcs H8; rewrite H12.
- assert (v1 =
- scalarMult (DTVector n) d4
- (UnitVector n (exist (fun n' : nat => n' < n) x l0))).
- ++ unfold scalarMult; simpl.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H10c x0); simpl in H10c.
- rewrite vmap_nth in H10c; simpl in H10c.
- match_option_in H10c.
- invcs H10c.
- destruct x0.
- unfold UnitVector; simpl.
- destruct (equiv_dec x0 x).
- ** red in e0.
- subst.
- erewrite index_pf_irrel in eqq6.
- rewrite eqq5 in eqq6.
- invcs eqq6; lra.
- ** lra.
- ++ replace (1 * d4)%R with d4 in H10 by lra.
- generalize (scalarMult_backprop_grad_scalar
- σ l s grad_env grad_env
- (UnitVector n (exist (fun n' : nat => n' < n) x l0)) d4)
- ; intros.
- simpl in H11; cut_to H11; trivial; try congruence.
- ** unfold df_eval_backprop_delta in H11.
- rewrite eqq1 in H11.
- unfold lift in H11; simpl in H11.
- rewrite H8 in eqq3.
- unfold scalarMult in eqq3; simpl in eqq3.
- match_option_in H7.
- rewrite eqq3, eqq6 in H11.
- invcs H11; rewrite H14.
- invcs H7; rewrite H11.
- generalize
- (df_eval_deriv_genvar_same
- [mk_env_entry (v, DTfloat)
- (d (exist (fun n' : nat => n' < n) x l0) )]
- s0 v); simpl; intros.
- specialize (H7 H H2).
- unfold lift, df_eval_deriv_gen_top in H7; simpl in H7.
- rewrite <- H12.
- unfold mk_genvar_env in eqq4; simpl in eqq4.
- rewrite eqq4, eqq5 in H7.
- invcs H7; lra.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- -- apply vectoro_to_ovector_exists_None in eqq2; destruct eqq2.
- rewrite vmap_nth in e; simpl in e.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x0)]
- (df_deriv s0 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- match_option_in e; tauto.
- * apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0.
- assert (df_eval_deriv_genvar [mk_env_entry (v, DTfloat) (d x)] s0
- [mk_env_entry (v, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option_in e; tauto.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- - Case "MatrixApply"%string.
- destruct closed.
- generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (0%R) :: nil) s0); intros.
- specialize (H4 H2).
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H3).
- match_case; [intros|tauto].
- specialize (H1 grad_env vin H3).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto].
- rewrite eqq1 in H1; simpl in H1.
- match_option
- ; rewrite eqq in H1; unfold lift in H1; symmetry in H1.
- + specialize (apply vectoro_to_ovector_forall_some_f H1); intros; simpl in H7.
- unfold lift; simpl.
- match_option.
- * specialize (apply vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H8 i).
- specialize (H7 i).
- match_nested_case.
- -- assert ( df_eval_backprop_deriv σ l grad_env m1 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- f_equal.
- simpl in H7.
- specialize (vectoro_to_ovector_forall_some_f H8); intros.
- specialize (vectoro_to_ovector_forall_some_f H7); intros.
- specialize (H10 i0); simpl in H10.
- match_option_in H10.
- specialize (vectoro_to_ovector_forall_some_f eqq2); intros.
- assert (H12c := H12).
- specialize (H12 i); simpl in H12.
- specialize (vectoro_to_ovector_forall_some_f H12); intros.
- specialize (H13 i0); simpl in H13; unfold mmap in H13.
- rewrite vmap_nth in H13; simpl in H13.
- rewrite vmap_nth in H13; simpl in H13.
- destruct i; destruct i0; simpl in H13.
- simpl in *.
- unfold matrix_zip in H13.
- rewrite vmap_nth in H13; simpl in H13.
- match_option_in H13; simpl in H13.
- unfold UnitMatrix in H13; simpl in H13.
- destruct (equiv_dec x x); [|congruence].
- destruct (equiv_dec x0 x0); [|congruence].
- invcs H13; invcs H10.
- assert (m1 =
- scalarMult (DTMatrix m n) d4
- (UnitMatrix m n
- (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1))).
- ++ unfold scalarMult; simpl.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H12c x1); simpl in H12c.
- specialize (vectoro_to_ovector_forall_some_f H12c); intros.
- specialize (H10 x2); simpl in H10; unfold mmap in H10.
-
- rewrite vmap_nth in H10; simpl in H10.
- rewrite vmap_nth in H10; simpl in H10.
- unfold matrix_zip in H10.
- rewrite vmap_nth in H10; simpl in H10.
- match_option_in H10.
- invcs H10.
- destruct x1; destruct x2.
- unfold UnitMatrix; simpl.
- destruct (equiv_dec x1 x).
- ** red in e1; subst.
- destruct (equiv_dec x2 x0).
- --- red in e1; subst.
- rewrite index_pf_irrel with (pf2 := l0) in eqq6.
- rewrite index_pf_irrel with (pf1 := l3) (pf2 := l1) in eqq6.
- rewrite eqq5 in eqq6.
- invcs eqq6; lra.
- --- lra.
- ** lra.
- ++ replace (1 * d4)%R with d4 in H15 by lra.
- generalize (scalarMult_backprop_grad_scalar
- σ l s grad_env grad_env
- (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1)) d4)
- ; intros.
- simpl in H13; cut_to H13; trivial; try congruence.
- ** unfold df_eval_backprop_delta in H13.
- rewrite eqq1 in H13.
- unfold lift in H13; simpl in H13.
- rewrite H10 in eqq3.
- unfold scalarMult in eqq3; simpl in eqq3.
- specialize (H11 (exist (fun n' : nat => n' < n) x0 l1)).
- simpl in H11.
- match_option_in H11.
- rewrite eqq3, eqq6 in H13.
-
- generalize
- (df_eval_deriv_genvar_same
- [mk_env_entry (v, DTfloat)
- (d (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1))]
- s0 v).
- simpl; intros.
- specialize (H16 H H2).
- unfold lift, df_eval_deriv_gen_top in H16; simpl in H16.
- unfold mk_genvar_env in eqq4; simpl in eqq4.
- rewrite eqq4, eqq5 in H16.
- invcs H13; rewrite H18.
- invcs H16.
- invcs H11; rewrite H15; lra.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- -- unfold matrixo_to_omatrix in eqq2.
- apply vectoro_to_ovector_exists_None in eqq2; destruct eqq2.
- apply vectoro_to_ovector_exists_None in e; destruct e.
- unfold mmap in e.
- rewrite vmap_nth in e; simpl in e.
- rewrite vmap_nth in e; simpl in e.
- rewrite matrix_zip_m_n in e.
- destruct i; destruct i0.
- simpl in e.
- assert (df_eval [mk_env_entry (v, DTfloat)
- (d x x0)]
- (df_deriv s0 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- match_option_in e; tauto.
- * unfold matrixo_to_omatrix in eqq0.
- apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0.
- apply vectoro_to_ovector_exists_None in e; destruct e.
- assert (df_eval_deriv_genvar [mk_env_entry (v, DTfloat) (d x x0)] s0
- [mk_env_entry (v, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option_in e; tauto.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- - Case "VLossfun"%string.
- destruct closed.
- generalize (eval_fully_closed_not_none
- (mk_env_entry (v1, DTfloat) (0%R) ::
- (mk_env_entry (v2, DTfloat) (0%R) :: nil)) s0); intros.
- specialize (H4 H2).
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H3).
- match_case; [intros|tauto].
- specialize (H1 grad_env vin H3).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto].
- rewrite eqq1 in H1; simpl in H1.
- match_option
- ; rewrite eqq in H1; unfold lift in H1; symmetry in H1.
- + specialize (vectoro_to_ovector_forall_some_f H1); intros; simpl in H7.
- unfold lift; simpl.
- match_nested_case.
- * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8.
- symmetry.
- match_nested_case.
- -- specialize (vectoro_to_ovector_forall_some_f eqq2); intros; simpl in H9.
- assert ( df_eval_backprop_deriv σ l grad_env v0 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- unfold snd in d1; simpl in d1.
- assert (forall i : {n' : nat | n' < n},
- df_eval [mk_env_entry (v1, DTfloat) (d i); mk_env_entry (v2, DTfloat) (r i)]
- (df_deriv s0 (v1, DTfloat))
- = Some (v0 i)).
- intros ; specialize (H9 i)
- ; rewrite vmap_nth in H9
- ; simpl in H9
- ; match_option_in H9
- ;inversion H9 ;f_equal; lra.
- generalize (df_eval_backprop_delta_by_unit_partvec σ l s grad_env v0 v).
- intros.
- specialize (H12 H3 vin).
- cut_to H12.
- ++ unfold df_eval_backprop_delta, lift in H12.
- now rewrite eqq1, eqq3 in H12.
- ++ intros.
- replace (@scaleUnitVector (@float floatish_R) n i (v0 i) (IZR Z0)) with
- (scalarMult (DTVector n) (v0 i) (UnitVector n i)) by
- (unfold scalarMult, scaleUnitVector, UnitVector;
- apply FunctionalExtensionality.functional_extensionality; intros;
- simpl; destruct (equiv_dec (` x) (` i)); lra).
- rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env); trivial.
- ** unfold df_eval_backprop_delta.
- rewrite eqq1.
- specialize (H7 i); simpl in H7.
- specialize (H8 i); simpl in H8.
- specialize (H9 i); simpl in H9.
- generalize
- (df_eval_deriv_genvar_same
- [mk_env_entry (v1, DTfloat) (d i); mk_env_entry (v2, DTfloat) (r i)]
- s0 v1); simpl; intros.
- specialize (H13 H H2).
- unfold df_eval_deriv_gen_top in H13; simpl in H13.
- unfold lift in H13; simpl in H13.
- match_option_in H13.
- --- unfold mk_genvar_env in H8; simpl in H8.
- rewrite eqq4 in H8.
- specialize (H11 i).
- destruct i; simpl in H7.
- match_option_in H7.
- unfold lift; f_equal.
- unfold scalarMult; simpl.
- invcs H8.
- invcs H7.
- rewrite H14.
- rewrite <- H13 in H11.
- invcs H11.
- lra.
- --- assert (df_eval_deriv_genvar
- [mk_env_entry (v1, DTfloat) (d i);
- mk_env_entry (v2, DTfloat) (r i)] s0
- [mk_env_entry (v1, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- -- apply vectoro_to_ovector_exists_None in eqq2.
- destruct eqq2.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x);
- mk_env_entry (v2, DTfloat) (r x)]
- (df_deriv s0 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none; simpl.
- apply fully_closed_deriv; trivial.
- tauto.
- * apply vectoro_to_ovector_exists_None in eqq0.
- destruct eqq0.
- match_option_in e.
- assert (df_eval_deriv_genvar
- [mk_env_entry (v1, DTfloat) (d x);
- mk_env_entry (v2, DTfloat) (r x)] s0
- [mk_env_entry (v1, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- - Case "MLossfun"%string.
- destruct closed.
- generalize (eval_fully_closed_not_none
- (mk_env_entry (v1, DTfloat) (0%R) ::
- (mk_env_entry (v2, DTfloat) (0%R) :: nil)) s0); intros.
- specialize (H4 H2).
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H3).
- match_case; [intros|tauto].
- specialize (H1 grad_env vin H3).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto].
- rewrite eqq1 in H1; simpl in H1.
- match_option
- ; rewrite eqq in H1; unfold lift in H1; symmetry in H1.
- + specialize (vectoro_to_ovector_forall_some_f H1); intros; simpl in H7.
- unfold lift; simpl.
- match_nested_case.
- * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8.
- symmetry.
- match_nested_case.
- -- specialize (vectoro_to_ovector_forall_some_f eqq2); intros; simpl in H9.
- assert ( df_eval_backprop_deriv σ l grad_env m1 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- unfold snd in d1; simpl in d1.
- assert (forall (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}),
- match
- df_eval [mk_env_entry (v1, DTfloat) (d i j);
- mk_env_entry (v2, DTfloat) (r i j)]
- (df_deriv s0 (v1, DTfloat))
- with
- | Some se => Some (1 * se / IZR (Z.of_nat n))%R
- | None => None
- end = Some (m1 i j)); intros.
- ++ specialize (H9 i).
- specialize (vectoro_to_ovector_forall_some_f H9); intros.
- specialize (H11 j); simpl in H11; unfold mmap in H11.
- do 2 rewrite vmap_nth in H11; simpl in H11.
- unfold matrix_zip, vector_zip in H11; simpl in H11.
- rewrite vmap_nth in H11; simpl in H11.
- match_option; rewrite eqq4 in H11; apply H11.
- ++ generalize (df_eval_backprop_delta_by_unit_partmat
- σ l s grad_env m1
- (mmap (fun u => u / IZR (Z.of_nat n))%R m0)); intros.
- specialize (H12 H3 vin).
- cut_to H12.
- ** unfold df_eval_backprop_delta, lift in H12.
- rewrite eqq1, eqq3 in H12.
- replace ((@msum floatish_R m n m0) / IZR (Z.of_nat n))%R with
- (msum (mmap (fun u : R => (u / IZR (Z.of_nat n))%R) m0)).
- --- apply H12.
- --- now rewrite msum_mmap_div_denom.
- ** intros.
- rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env)
- ; trivial.
- --- unfold df_eval_backprop_delta.
- rewrite eqq1.
- specialize (H7 i); simpl in H7.
- specialize (H8 i); simpl in H8.
- specialize (H11 i j); simpl in H11.
- specialize (vectoro_to_ovector_forall_some_f H7); intros.
- specialize (H13 j); simpl in H13.
- specialize (vectoro_to_ovector_forall_some_f H8); intros.
- specialize (H14 j); simpl in H14.
- generalize
- (df_eval_deriv_genvar_same
- [mk_env_entry (v1, DTfloat) (d i j);
- mk_env_entry (v2, DTfloat) (r i j)]
- s0 v1); simpl; intros.
- specialize (H15 H H2).
- unfold df_eval_deriv_gen_top in H15; simpl in H15.
- unfold lift in H15; simpl in H15.
- match_option_in H15.
- +++ unfold mk_genvar_env in H14; simpl in H14.
- rewrite eqq4 in H14.
- destruct i; destruct j; simpl in H13.
- match_option_in H13.
- unfold lift; f_equal.
- invcs H13.
- invcs H14.
- rewrite H17.
- unfold mmap.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite <- H16.
- rewrite <- H15 in H11.
- invcs H11.
- lra.
- +++ assert (df_eval_deriv_genvar
- [mk_env_entry (v1, DTfloat) (d i j);
- mk_env_entry (v2, DTfloat) (r i j)] s0
- [mk_env_entry (v1, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- --- apply backprop_deriv_fully_closed_not_none; trivial.
- --- apply backprop_deriv_fully_closed_not_none; trivial.
- -- unfold matrixo_to_omatrix in eqq2.
- apply vectoro_to_ovector_exists_None in eqq2.
- destruct eqq2.
- apply vectoro_to_ovector_exists_None in e.
- destruct e.
- unfold mmap in e.
- rewrite vmap_nth in e; simpl in e.
- rewrite vmap_nth in e; simpl in e.
- unfold matrix_zip,vector_zip in e; simpl in e.
- rewrite vmap_nth in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)]
- (df_deriv s0 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none; simpl.
- apply fully_closed_deriv; trivial.
- tauto.
- * unfold matrixo_to_omatrix in eqq0.
- apply vectoro_to_ovector_exists_None in eqq0.
- destruct eqq0.
- apply vectoro_to_ovector_exists_None in e.
- destruct e.
- match_option_in e.
- assert (df_eval_deriv_genvar
- [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)] s0
- [mk_env_entry (v1, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- Qed.
-
-(*
-Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "DVector"%string
- | Case_aux c "DMatrix"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string
- | Case_aux c "VectorDot"%string
- | Case_aux c "VectorSum"%string
- | Case_aux c "MatrixSum"%string
- | Case_aux c "VectorElem"%string
- | Case_aux c "MatrixElem"%string
- | Case_aux c "MatrixVectorMult"%string
- | Case_aux c "MatrixVectorAdd"%string
- | Case_aux c "MatrixMult"%string
- | Case_aux c "VectorPlus"%string
- | Case_aux c "VectorMinus"%string
- | Case_aux c "MatrixPlus"%string
- | Case_aux c "MatrixMinus"%string
- | Case_aux c "VectorScalMult"%string
- | Case_aux c "MatrixScalMult"%string
- | Case_aux c "VectorApply"%string
- | Case_aux c "MatrixApply"%string
- | Case_aux c "VLossfun"%string
- | Case_aux c "MLossfun"%string].
-
-
- Lemma tree_backpropeq_complete_gen {T} (env gradenv : df_env)
- (dfexpr : DefinedFunction EvalAnn T) (grad : definition_function_types_interp T) :
- forall (x : SubVar),
- let xvar := (x, DTfloat) in
- vartlookup gradenv (x,DTfloat) <> None ->
- match df_eval_tree_deriv env dfexpr xvar,
- backprop_lookup (Some gradenv) xvar,
- backprop_lookup (df_eval_tree_backprop_deriv env dfexpr gradenv grad) xvar
- with
- | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1
- | None, _, None => True
- | _, _, _ => False
- end.
- Proof.
- simpl.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case.
-
- - Case "Number"%string.
- intros _ _ grad gradenv xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Constant"%string.
- intros _ _ grad gradenv xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Var"%string.
- intros sv _ grad gradenv xinn.
- case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto].
- destruct (var_dec x sv); simpl.
- + subst.
- rewrite H; simpl.
- rewrite lookup_update.
- destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence].
- unfold addvar; simpl.
- rewrite H.
- lra.
- + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ].
- case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros.
- * rewrite lookup_update_neq by congruence.
- rewrite H.
- lra.
- * rewrite H.
- lra.
- - Case "Plus"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr grad ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' grad) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Minus"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Times"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (get_annotation r * grad)%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (get_annotation l * grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (get_annotation l * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl (get_annotation r * grad)%R gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Divide"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (grad / get_annotation r)%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl (grad / get_annotation r)%R gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r )%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Square"%string.
- intros _ e IHe grad gradenv xinn.
-
- specialize (IHe (2 * (get_annotation e) * grad)%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (2 * (get_annotation e) * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Exp"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad * exp (get_annotation e))%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * exp (get_annotation e))%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
-
- - Case "Log"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad / get_annotation e)%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad / get_annotation e)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
-
- - Case "Abs"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad * (sign (get_annotation e)))%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * (sign (get_annotation e)))%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Sign"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe 0%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- replace (de * 0)%R with (0)%R in IHe by lra.
- replace (0 * grad)%R with (0)%R by lra.
- apply IHe.
- - Case "PSign"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe 0%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- replace (de * 0)%R with (0)%R in IHe by lra.
- replace (0 * grad)%R with (0)%R by lra.
- apply IHe.
- - Case "Max"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- specialize (IHl grad gradenv xinn).
- specialize (IHr grad gradenv xinn).
- destruct (Rle_dec (get_annotation l) (get_annotation r)); simpl.
- destruct (df_eval_tree_deriv env r (x, DTfloat)); simpl; trivial.
- destruct (df_eval_tree_deriv env l (x, DTfloat)); simpl; trivial.
- Qed.
-*)
-
-End real_pfs.
-
-(*
- Section FreeVariablesExample.
- (* We need to open the string scope in order to use "a" as a string. *)
- Open Scope string_scope.
- Theorem ex1 : (df_free_variables (Plus (Var "a") (Var "b"))) = "a"::"b"::nil.
- Proof.
- (* Reflexivity doesn't need syntactically identical things on either side of =.
- * It suffices that the left-hand side beta-reduced to the right-hand side. *)
- reflexivity.
- Qed.
-
- Close Scope string_scope.
-
- End FreeVariablesExample.
-*)
diff --git a/coq/NeuralNetworks/Gen_NN.v b/coq/NeuralNetworks/Gen_NN.v
deleted file mode 100644
index b8d25204..00000000
--- a/coq/NeuralNetworks/Gen_NN.v
+++ /dev/null
@@ -1,827 +0,0 @@
-Require Import String.
-Require Import RelationClasses.
-Require Import EquivDec.
-Require Import Streams.
-Require Import List.
-Require Import ListAdd.
-Require Import Rbase Rtrigo Rpower Rbasic_fun.
-Require Import DefinedFunctions.
-Require Import Vector.
-Require Import Lra Lia.
-
-Require Import Floatish Utils.
-
-Section GenNN.
-
- Context {floatish_impl:floatish}.
-
- Local Open Scope float.
-
- Definition UnitDefinedFunction := @DefinedFunction floatish_impl UnitAnn.
-
- Record FullNN : Type := mkNN { ldims : list nat; param_var : SubVar;
- f_activ : UnitDefinedFunction DTfloat; f_loss : UnitDefinedFunction DTfloat }.
-
- Definition mkSubVarVector (v : SubVar) (n : nat) : UnitDefinedFunction (DTVector n) :=
- DVector tt (fun i => Var (Sub v (proj1_sig i), DTfloat) tt).
-
- Definition mkVarVector (v : SubVar) (n : nat) : UnitDefinedFunction (DTVector n) :=
- Var (v, DTVector n) tt.
-
- Definition mkSubVarMatrix (v : SubVar) (n m : nat) : UnitDefinedFunction (DTMatrix n m) :=
- DMatrix tt (fun i j => Var (Sub (Sub v (proj1_sig i)) (proj1_sig j), DTfloat) tt).
-
- Definition mkVarMatrix (v : SubVar) (n m : nat) : UnitDefinedFunction (DTMatrix n m) :=
- Var (v, DTMatrix n m) tt.
-
- Definition unique_var {Ann} (df : DefinedFunction Ann DTfloat) : option var_type :=
- let fv := nodup vart_dec (df_free_variables df) in
- match fv with
- | nil => None
- | v :: nil => Some v
- | _ => None
- end.
-
- Definition activation (df : UnitDefinedFunction DTfloat) (vec : list (UnitDefinedFunction DTfloat)) : option (list (UnitDefinedFunction DTfloat)) :=
- match unique_var df with
- | Some (v, DTfloat) => Some (map (fun dfj => df_subst df (v, DTfloat) dfj) vec)
- | _ => None
- end.
-
- Definition create_activation_fun (df : UnitDefinedFunction DTfloat) : option (UnitDefinedFunction DTfloat -> UnitDefinedFunction DTfloat) :=
- match unique_var df with
- | Some (v, DTfloat) => Some (fun val => df_subst df (v, DTfloat) val)
- | _ => None
- end.
-
- Definition mkNN2 (n1 n2 n3 : nat) (ivar wvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : (UnitDefinedFunction (DTVector n3)) :=
- let mat1 := mkSubVarMatrix (Sub wvar 1) n2 n1 in
- let mat2 := mkSubVarMatrix (Sub wvar 2) n3 n2 in
- let ivec := mkVarVector ivar n1 in
- let N1 := VectorApply tt f_activ_var f_activ (MatrixVectorMult tt mat1 ivec) in
- VectorApply tt f_activ_var f_activ (MatrixVectorMult tt mat2 N1).
-
- Definition mkVarNN2 (n1 n2 n3 : nat) (ivar wvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : (UnitDefinedFunction (DTVector n3) * list var_type) :=
- let mat1 := (Sub wvar 1, DTMatrix n2 n1) in
- let mat2 := (Sub wvar 2, DTMatrix n3 n2) in
- let ivec := Var (ivar, DTVector n1) tt in
- let N1 := VectorApply tt f_activ_var f_activ (MatrixVectorMult tt (Var mat1 tt) ivec) in
- (VectorApply tt f_activ_var f_activ (MatrixVectorMult tt (Var mat2 tt) N1),
- mat1 :: mat2 :: nil).
-
- Definition mkVarMatNN2 (n1 n2 n3 nsamp: nat) (ivar wvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : (UnitDefinedFunction (DTMatrix n3 nsamp) * list var_type) :=
- let mat1 := (Sub wvar 1, DTMatrix n2 n1) in
- let mat2 := (Sub wvar 2, DTMatrix n3 n2) in
- let imat := Var (ivar, DTMatrix n1 nsamp) tt in
- let N1 := MatrixApply tt f_activ_var f_activ (MatrixMult tt (Var mat1 tt) imat) in
- (MatrixApply tt f_activ_var f_activ (MatrixMult tt (Var mat2 tt) N1),
- mat1 :: mat2 :: nil).
-
- Definition mkNN_bias_step (n1 n2 : nat) (ivec : UnitDefinedFunction (DTVector n1))
- (mat : UnitDefinedFunction (DTMatrix n2 n1))
- (bias : UnitDefinedFunction (DTVector n2))
- (f_activ_var : SubVar) (f_activ : UnitDefinedFunction DTfloat)
- : UnitDefinedFunction (DTVector n2) :=
- VectorApply tt f_activ_var f_activ (VectorPlus tt (MatrixVectorMult tt mat ivec) bias).
-
- Definition mkNN2_bias (n1 n2 n3 : nat) (ivar wvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : UnitDefinedFunction (DTVector n3) :=
- let mat1 := mkSubVarMatrix (Sub wvar 1) n2 n1 in
- let b1 := mkSubVarVector (Sub wvar 1) n2 in
- let mat2 := mkSubVarMatrix (Sub wvar 2) n3 n2 in
- let b2 := mkSubVarVector (Sub wvar 2) n3 in
- let ivec := mkVarVector ivar n1 in
- let N1 := mkNN_bias_step n1 n2 ivec mat1 b1 f_activ_var f_activ in
- mkNN_bias_step n2 n3 N1 mat2 b2 f_activ_var f_activ.
-
- Definition mkNN2_Var_bias (n1 n2 n3 : nat) (ivar wvar bvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : (UnitDefinedFunction (DTVector n3) * list var_type) :=
- let mat1 := (Sub wvar 1, DTMatrix n2 n1) in
- let b1 := (Sub bvar 1, DTVector n2) in
- let mat2 := (Sub wvar 2, DTMatrix n3 n2) in
- let b2 := (Sub bvar 2, DTVector n3) in
- let ivec := Var (ivar, DTVector n1) tt in
- let N1 := mkNN_bias_step n1 n2 ivec (Var mat1 tt) (Var b1 tt) f_activ_var f_activ in
- (mkNN_bias_step n2 n3 N1 (Var mat2 tt) (Var b2 tt) f_activ_var f_activ,
- mat1 :: b1 :: mat2 :: b2 :: nil).
-
- Definition mkNN_Mat_bias_step (n1 n2 nsamp : nat) (imat : UnitDefinedFunction (DTMatrix n1 nsamp))
- (mat : UnitDefinedFunction (DTMatrix n2 n1))
- (bias : UnitDefinedFunction (DTVector n2))
- (f_activ_var : SubVar) (f_activ : UnitDefinedFunction DTfloat)
- : UnitDefinedFunction (DTMatrix n2 nsamp) :=
- MatrixApply tt f_activ_var f_activ (MatrixVectorAdd tt (MatrixMult tt mat imat) bias).
-
- Definition mkNN2_Var_Mat_bias (n1 n2 n3 nsamp: nat) (ivar wvar bvar: SubVar) (f1_activ f2_activ : UnitDefinedFunction DTfloat) (f1_activ_var f2_activ_var : SubVar) : (UnitDefinedFunction (DTMatrix n3 nsamp) * list var_type) :=
- let mat1 := (Sub wvar 1, DTMatrix n2 n1) in
- let b1 := (Sub bvar 1, DTVector n2) in
- let mat2 := (Sub wvar 2, DTMatrix n3 n2) in
- let b2 := (Sub bvar 2, DTVector n3) in
- let imat := Var (ivar, DTMatrix n1 nsamp) tt in
- let N1 := mkNN_Mat_bias_step n1 n2 nsamp imat (Var mat1 tt) (Var b1 tt) f1_activ_var f1_activ in
- (mkNN_Mat_bias_step n2 n3 nsamp N1 (Var mat2 tt) (Var b2 tt) f2_activ_var f2_activ,
- mat1 :: b1 :: mat2 :: b2 :: nil).
-
- Lemma vector_float_map_last_rewrite {B nvlist1 n2 v n1} :
- (DTVector (last ((@domain _ B) nvlist1) n2)) =
- (DTVector (last (domain((n2, v) :: nvlist1)) n1)).
- Proof.
- rewrite domain_cons.
- rewrite last_cons.
- reflexivity.
- Qed.
-
- Lemma matrix_float_map_last_rewrite {B nvlist1 n2 v n1 nsamp} :
- (DTMatrix (last ((@domain _ B) nvlist1) n2) nsamp) =
- (DTMatrix (last (domain((n2, v) :: nvlist1)) n1) nsamp).
- Proof.
- rewrite domain_cons.
- rewrite last_cons.
- reflexivity.
- Qed.
-
- Fixpoint mkNN_gen_0 (n1:nat) (nvlist : list (nat * SubVar))
- (ivec : (UnitDefinedFunction (DTVector n1)))
- (f_activ_var : SubVar ) (f_activ : UnitDefinedFunction DTfloat) :
- UnitDefinedFunction (DTVector (last (domain nvlist) n1))
-:=
- match nvlist with
- | nil => ivec
- | cons (n2,v) nvlist1 =>
- let mat := mkSubVarMatrix v n2 n1 in
- let b := mkSubVarVector v n2 in
- let N := mkNN_bias_step n1 n2 ivec mat b f_activ_var f_activ in
- eq_rect _ UnitDefinedFunction (mkNN_gen_0 n2 nvlist1 N f_activ_var f_activ) _ vector_float_map_last_rewrite
- end.
-
- Fixpoint mkNN_Var_gen_0 (n1:nat) (nvlist : list (nat * SubVar))
- (ivec : (UnitDefinedFunction (DTVector n1)))
- (f_activ_var : SubVar ) (f_activ : UnitDefinedFunction DTfloat) :
- UnitDefinedFunction (DTVector (last (domain nvlist) n1))
-:=
- match nvlist with
- | nil => ivec
- | cons (n2,v) nvlist1 =>
- let mat := mkVarMatrix v n2 n1 in
- let b := mkVarVector v n2 in
- let N := mkNN_bias_step n1 n2 ivec mat b f_activ_var f_activ in
- eq_rect _ UnitDefinedFunction (mkNN_Var_gen_0 n2 nvlist1 N f_activ_var f_activ) _ vector_float_map_last_rewrite
- end.
-
- Fixpoint mkNN_Mat_Var_gen_0 (nsamp n1:nat) (nvlist : list (nat * SubVar))
- (imat : (UnitDefinedFunction (DTMatrix n1 nsamp)))
- (f_activ_var : SubVar ) (f_activ : UnitDefinedFunction DTfloat) :
- UnitDefinedFunction (DTMatrix (last (domain nvlist) n1) nsamp)
-:=
- match nvlist with
- | nil => imat
- | cons (n2,v) nvlist1 =>
- let mat := mkVarMatrix v n2 n1 in
- let b := mkVarVector v n2 in
- let N := mkNN_Mat_bias_step n1 n2 nsamp imat mat b f_activ_var f_activ in
- eq_rect _ UnitDefinedFunction (mkNN_Mat_Var_gen_0 nsamp n2 nvlist1 N f_activ_var f_activ) _ matrix_float_map_last_rewrite
- end.
-
- Program Definition mkNN_gen (n1:nat) (nlist : list nat) (ivar wvar f_activ_var : SubVar)
- (f_activ : UnitDefinedFunction DTfloat) :
- UnitDefinedFunction (DTVector (last nlist n1)) :=
- let vlist := map (fun i => Sub wvar i) (seq 1 (length nlist)) in
- let ivec := mkVarVector ivar n1 in
- eq_rect _ UnitDefinedFunction
- (mkNN_gen_0 n1 (combine nlist vlist) ivec f_activ_var f_activ) _ _.
- Next Obligation.
- f_equal.
- f_equal.
- rewrite combine_domain_eq; trivial.
- now rewrite map_length, seq_length.
- Qed.
-
- Program Definition mkNN_Var_gen (n1:nat) (nlist : list nat) (ivar wvar f_activ_var : SubVar)
- (f_activ : UnitDefinedFunction DTfloat) :
- UnitDefinedFunction (DTVector (last nlist n1)) :=
- let vlist := map (fun i => Sub wvar i) (seq 1 (length nlist)) in
- let ivec := mkVarVector ivar n1 in
- eq_rect _ UnitDefinedFunction
- (mkNN_Var_gen_0 n1 (combine nlist vlist) ivec f_activ_var f_activ) _ _.
- Next Obligation.
- f_equal.
- f_equal.
- rewrite combine_domain_eq; trivial.
- now rewrite map_length, seq_length.
- Qed.
-
- Program Definition mkNN_Mat_Var_gen (nsamp n1:nat) (nlist : list nat) (ivar wvar f_activ_var : SubVar)
- (f_activ : UnitDefinedFunction DTfloat) :
- UnitDefinedFunction (DTMatrix (last nlist n1) nsamp) :=
- let vlist := map (fun i => Sub wvar i) (seq 1 (length nlist)) in
- let imat := mkVarMatrix ivar n1 nsamp in
- eq_rect _ UnitDefinedFunction
- (mkNN_Mat_Var_gen_0 nsamp n1 (combine nlist vlist) imat f_activ_var f_activ) _ _.
- Next Obligation.
- f_equal.
- f_equal.
- rewrite combine_domain_eq; trivial.
- now rewrite map_length, seq_length.
- Qed.
-
- Definition softmax {n:nat} (NN : UnitDefinedFunction (DTVector n)) : UnitDefinedFunction (DTVector n) :=
- let expvar := Name "expvar" in
- let NNexp := VectorApply tt expvar (Exp tt (Var (expvar, DTfloat) tt)) NN in
- let NNexpscale := Divide tt (Number tt 1) (VectorSum tt NNexp) in
- VectorScalMult tt NNexpscale NNexp.
-
- Definition L2loss (nnvar ovar : SubVar) : UnitDefinedFunction DTfloat :=
- Square tt ( Minus tt (Var (nnvar, DTfloat) tt) (Var (ovar, DTfloat) tt) ).
-
- Definition L1loss (nnvar ovar : SubVar) : UnitDefinedFunction DTfloat :=
- Abs tt (Minus tt (Var (nnvar, DTfloat) tt) (Var (ovar, DTfloat) tt)).
-
- Definition Sigmoid (var : SubVar) : UnitDefinedFunction DTfloat :=
- Divide tt (Number tt 1)
- (Plus tt (Number tt 1) (Exp tt (Minus tt (Number tt 0) (Var (var, DTfloat) tt)))).
-
- Definition CrossEntropy (nnvar ovar : SubVar) : UnitDefinedFunction DTfloat :=
- let nnvar' := Var (nnvar, DTfloat) tt in
- let ovar' := Var (ovar, DTfloat) tt in
- Minus tt (Times tt (Minus tt ovar' (Number tt 1)) (Log tt (Minus tt (Number tt 1) nnvar')))
- (Times tt ovar' (Log tt nnvar')).
-
- Record testcases : Type := mkTest {ninput: nat; noutput: nat; ntest: nat;
- datavec : Vector ((Vector float ninput) * (Vector float noutput)) ntest}.
-
- Definition NNinstance1samp {ninput noutput : nat} (ivar : SubVar)
- (f_loss : UnitDefinedFunction DTfloat)
- (f_loss_NNvar f_loss_outvar : SubVar)
- (NN : UnitDefinedFunction (DTVector noutput)) (σ:df_env)
- (data: (Vector float ninput) * (Vector float noutput))
- : df_env * (UnitDefinedFunction DTfloat) :=
- let ipair := mk_env_entry (ivar, DTVector ninput) (fst data) in
- (cons ipair σ, VLossfun tt f_loss_NNvar f_loss_outvar f_loss NN (snd data)).
-
- Definition NNinstancebatch {ninput nsamp noutput : nat} (ivar : SubVar)
- (f_loss : UnitDefinedFunction DTfloat)
- (f_loss_NNvar f_loss_outvar : SubVar)
- (NN : UnitDefinedFunction (DTMatrix noutput nsamp)) (σ:df_env)
- (data: (Matrix float ninput nsamp) * (Matrix float noutput nsamp))
- : df_env * (UnitDefinedFunction DTfloat) :=
- let ipair := mk_env_entry (ivar, DTMatrix ninput nsamp) (fst data) in
- (cons ipair σ, MLossfun tt f_loss_NNvar f_loss_outvar f_loss NN (snd data)).
-
- Definition EvalNNinstance1samp {ninput noutput : nat} (ivar : SubVar)
- (f_loss : UnitDefinedFunction DTfloat)
- (f_loss_NNvar f_loss_outvar : SubVar)
- (NN : UnitDefinedFunction (DTVector noutput)) (σ:df_env)
- (data: (Vector float ninput) * (Vector float noutput))
- : option float :=
- let ipair := mk_env_entry (ivar, DTVector ninput) (fst data) in
- df_eval (cons ipair σ) (VLossfun tt f_loss_NNvar f_loss_outvar f_loss NN (snd data)).
-
- Definition EvalNNinstancebatch {ninput nsamp noutput : nat} (ivar : SubVar)
- (f_loss : UnitDefinedFunction DTfloat)
- (f_loss_NNvar f_loss_outvar : SubVar)
- (NN : UnitDefinedFunction (DTMatrix noutput nsamp)) (σ:df_env)
- (data: (Matrix float ninput nsamp) * (Matrix float noutput nsamp))
- : option float :=
- let ipair := mk_env_entry (ivar, DTMatrix ninput nsamp) (fst data) in
- df_eval (cons ipair σ) (MLossfun tt f_loss_NNvar f_loss_outvar f_loss NN (snd data)).
-
- (*
- Lemma NNinstance_unique_var (n1 n2 n3 : nat) (ivar : SubVar) (f_loss : DefinedFunction DTfloat)
- (NN2 : DefinedFunction (DTVector n3)) (inputs : (list float))
- (outputs : Vector float n3) (v:SubVar) :
- unique_var f_loss = Some v ->
- NNinstance n1 n2 n3 ivar f_loss NN2 inputs outputs =
- Some (
- let ipairs := (list_prod (map (fun n => (Sub ivar n)) (seq 1 n1))
- (map Number inputs)) in
- let losses := VectorMinus (df_subst_list NN2 ipairs)
- (DVector (vmap Number outputs)) in
- (VectorSum (VectorApply v f_loss losses))
- ).
- Proof.
- unfold NNinstance.
- intros.
- rewrite H.
- reflexivity.
- Qed.
-
- Lemma NNinstance_None (n1 n2 n3 : nat) (ivar : SubVar) (f_loss : DefinedFunction DTfloat)
- (NN2 : DefinedFunction (DTVector n3)) (inputs : (list float))
- (outputs : Vector float n3) :
- unique_var f_loss = None ->
- NNinstance n1 n2 n3 ivar f_loss NN2 inputs outputs = None.
- Proof.
- unfold NNinstance.
- intros.
- now rewrite H.
- Qed.
- *)
-
-(* Local Existing Instance floatish_interval. *)
-
- Definition lookup_list (σ:df_env) (lvar : list SubVar) : option (list float) :=
- listo_to_olist (map (fun v => (vartlookup σ (v, DTfloat)):option float) lvar).
-
- Definition combine_with {A:Type} {B:Type} {C:Type} (f: A -> B -> C )
- (lA : list A) (lB : list B) : list C :=
- map (fun '(a, b) => f a b) (combine lA lB).
-
- Definition combine3_with {A:Type} {B:Type} {C:Type} {D:Type} (f: A -> B -> C -> D)
- (lA : list A) (lB : list B) (lC : list C) : list D :=
- map (fun '(a, bc) => f a (fst bc) (snd bc)) (combine lA (combine lB lC)).
-
- Fixpoint streamtake (n : nat) {A : Type} (st : Stream A) : (list A) * (Stream A) :=
- match n with
- | 0 => (nil, st)
- | S n' => let rst := streamtake n' (Streams.tl st) in
- ((Streams.hd st)::(fst rst), snd rst)
- end.
-
- Lemma streamtake_n (n : nat) (A : Type) (st : Stream A) :
- length (fst (streamtake n st)) = n.
- Proof.
- generalize st.
- induction n.
- reflexivity.
- intros.
- simpl.
- f_equal.
- specialize IHn with (st := Streams.tl st0).
- apply IHn.
- Qed.
-
- Fixpoint env_update_first (l:df_env) (an:env_entry_type) : df_env
- := match l with
- | nil => nil
- | fv::os => if (projT1 an) == (projT1 fv) then an::os
- else fv::(env_update_first os an)
- end.
-
- Definition env_update_list (l up:df_env) : df_env
- := fold_left (env_update_first) up l.
-
- Definition optimize_step
- (step : nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lvar : list SubVar)
- (noise_st : Stream float) : (option df_env)*(Stream float) :=
- let lvart:list var_type := (map (fun v => (v, DTfloat)) lvar) in
- let ogradvec := df_eval_gradient σ df lvart in
- let alpha := 1 / (FfromZ (Z.of_nat (S step))) in
- let '(lnoise, nst) := streamtake (length lvar) noise_st in
- let olvals := lookup_list σ lvar in
- (match (ogradvec, olvals) with
- | (Some gradvec, Some lvals) =>
- Some (env_update_list
- σ
- (map (fun '(v,e) => mk_env_entry (v, DTfloat) (e:float))
- (combine lvar (combine3_with
- (fun val grad noise => val - alpha*(grad + noise))
- lvals gradvec lnoise))))
- | (_, _) => None
- end, nst).
-
- Fixpoint get_noise_vector (n: nat) (noise_st: Stream float) :
- (Vector float n) * (Stream float) :=
- match n with
- | 0 => (vnil, noise_st)
- | S n' =>
- let noise := Streams.hd noise_st in
- let nst := Streams.tl noise_st in
- let '(vec, nst') := get_noise_vector n' nst in
- (vcons noise vec, nst')
- end.
-
- Fixpoint get_noise_matrix (n m: nat) (noise_st: Stream float) :
- (Matrix float n m) * (Stream float) :=
- let '(vec, nst) := get_noise_vector m noise_st in
- match n with
- | 0 => (fun i => vec, nst)
- | S n' =>
- let '(mat, nst') := get_noise_matrix n' m nst in
- (vcons vec mat, nst')
- end.
-
- Definition get_noise (t:definition_function_types) (noise_st:Stream float) :
- (definition_function_types_interp t) * (Stream float) :=
- match t with
- | DTfloat => (Streams.hd noise_st, Streams.tl noise_st)
- | DTVector n => get_noise_vector n noise_st
- | DTMatrix m n => get_noise_matrix m n noise_st
- end.
-
- Program Definition update_val_gradenv (grad_env:df_env) (x:var_type) (alpha:float) :
- definition_function_types_interp (snd x) ->
- definition_function_types_interp (snd x) ->
- definition_function_types_interp (snd x) :=
- (match snd x as y return snd x = y ->
- definition_function_types_interp y ->
- definition_function_types_interp y ->
- definition_function_types_interp y with
- | DTfloat => fun pf val noise =>
- match vartlookup grad_env x with
- | Some grad => val - alpha * (coerce _ grad + noise)
- | _ => val
- end
- | DTVector n => fun pf val noise =>
- match vartlookup grad_env x with
- | Some grad =>
- fun i =>
- (val i) - alpha * (((coerce _ grad) i) + (noise i))
- | _ => val
- end
- | DTMatrix m n => fun pf val noise =>
- match vartlookup grad_env x with
- | Some grad =>
- fun i j =>
- (val i j) - alpha * (((coerce _ grad) i j) + (noise i j))
- | _ => val
- end
- end) (eq_refl _)
- .
- Next Obligation.
- rewrite pf; reflexivity.
- Qed.
- Next Obligation.
- rewrite pf; reflexivity.
- Qed.
- Next Obligation.
- rewrite pf; reflexivity.
- Qed.
-
-
-
- Definition update_entry (entry: env_entry_type) (grad_env:df_env) (alpha:float)
- (noise_st : Stream float) : (env_entry_type*Stream float) :=
- let x := projT1 entry in
- let val := projT2 entry in
- let '(noise, nst) := get_noise (snd x) noise_st in
- (mk_env_entry x (update_val_gradenv grad_env x alpha val noise), nst).
-
- Fixpoint list_arg_iter {A B} (f: A -> B -> A * B)
- (l:list A) (b: B) : (list A)*B :=
- match l with
- | nil => (l, b)
- | a :: l' =>
- let '(na, b') := f a b in
- let '(nl, nb) := list_arg_iter f l' b' in
- (na::nl, nb)
- end.
-
- Definition harmonic_lr (step : nat) : float := 1 / (FfromZ (Z.of_nat (S step))).
-
- Definition optimize_step_backprop
- (step : nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lr : nat -> float)
- (noise_st : Stream float) (dvars : list var_type)
- : (option df_env)*(Stream float) :=
- match df_eval_backprop_deriv σ df (gradenv_init dvars) 1 with
- | Some gradenv =>
- let alpha := lr step in
- let '(env, nst) := list_arg_iter (fun a b => update_entry a gradenv alpha b) σ noise_st in
- (Some env, nst)
- | _ => (None, noise_st)
- end.
-
- Definition optimize_step_tree_backprop
- (step : nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lr : nat -> float)
- (noise_st : Stream float) (dvars : list var_type)
- : (option df_env)*(Stream float) :=
- match df_eval_tree σ df with
- | Some df_tree =>
- match df_eval_tree_backprop_deriv (* σ not-needed *) nil df_tree (gradenv_init dvars) 1 with
- | Some gradenv =>
- let alpha := lr step in
- let '(env, nst) := list_arg_iter (fun a b => update_entry a gradenv alpha b) σ noise_st in
- (Some env, nst)
- | _ => (None, noise_st)
- end
- | _ => (None, noise_st)
- end.
-
- Fixpoint optimize_steps
- (start count:nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lvar : list SubVar)
- (noise_st : Stream float) : (option df_env)*(Stream float) :=
- match count with
- | 0 => (Some σ, noise_st)
- | S n =>
- match optimize_step start df σ lvar noise_st with
- | (Some σ', noise_st') => optimize_steps (S start) n df σ' lvar noise_st'
- | (None, noise_st') => (None, noise_st')
- end
- end.
-
- Fixpoint optimize_steps_backprop
- (start count:nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lr : nat -> float)
- (noise_st : Stream float) (dvars : list var_type)
- : (option df_env)*(Stream float) :=
- match count with
- | 0 => (Some σ, noise_st)
- | S n =>
- match optimize_step_backprop start df σ lr noise_st dvars with
- | (Some σ', noise_st') => optimize_steps_backprop (S start) n df σ' lr noise_st' dvars
- | (None, noise_st') => (None, noise_st')
- end
- end.
-
- Fixpoint optimize_steps_tree_backprop
- (start count:nat) (df : UnitDefinedFunction DTfloat) (σ:df_env) (lr : nat -> float)
- (noise_st : Stream float) (dvars : list var_type)
- : (option df_env)*(Stream float) :=
- match count with
- | 0 => (Some σ, noise_st)
- | S n =>
- match optimize_step_tree_backprop start df σ lr noise_st dvars with
- | (Some σ', noise_st') => optimize_steps_tree_backprop (S start) n df σ' lr noise_st' dvars
- | (None, noise_st') => (None, noise_st')
- end
- end.
-
-
-Definition Fmax (a b:float) : float :=
- if Fgt a b then a else b.
-
-Definition Fmin (a b:float) : float :=
- if Flt a b then a else b.
-
-Definition vmax {n : nat} (vec : Vector float n)
- := vector_fold_right1 Fmax (FfromZ 0) id vec.
-
-Definition vmin {n : nat} (vec : Vector float n)
- := vector_fold_right1 Fmin (FfromZ 0) id vec.
-
-Definition nrows {A} (l : list (list A)) := List.length l.
-
-Definition ncols {A} (l : list (list A)) :=
- match l with
- | nil => 0%nat
- | r :: _ => List.length r
- end.
-
-Definition normalizeIntData (l:list (list Z)) : Matrix float (nrows l) (ncols l) :=
- let mat : Matrix float (nrows l) (ncols l) :=
- fun i j => FfromZ (List.nth (proj1_sig j) (List.nth (proj1_sig i) l nil) (0)%Z) in
- let tmat := transpose mat in
- let maxes := vmap vmax tmat in
- let mins := vmap vmin tmat in
- fun i j => ((mat i j) - (mins j))/((maxes j)- (mins j)).
-
-Program Definition splitLastCol {nsamp ncols : nat} (data : Matrix float nsamp ncols)
- (pf : (ncols > 0)%nat)
- : (Matrix float (ncols-1) nsamp) * (Matrix float 1 nsamp) :=
- (fun i j => data j i, fun i j => data j (i + ncols-1)%nat).
-Next Obligation.
- lia.
-Defined.
-Next Obligation.
- lia.
-Defined.
-
-Definition init_env2 (dim1 dim2 dim3 : nat) (w b : string)
- (ranm1 : Matrix float dim2 dim1)
- (ranm2 : Matrix float dim3 dim2) : df_env :=
- let wvar := Name w in
- let bvar := Name b in
- let wvar1 := (Sub wvar 1, DTMatrix dim2 dim1) in
- let wvar2 := (Sub wvar 2, DTMatrix dim3 dim2) in
- let bvar1 := (Sub bvar 1, DTVector dim2) in
- let bvar2 := (Sub bvar 2, DTVector dim3) in
- (* (mk_env_entry wvar1 ranm1) :: (mk_env_entry wvar2 ranm2) :: *)
- (mk_env_entry wvar1 (fun i j => (ranm1 i j) / (Fsqrt (FfromZ (Z.of_nat dim1))))) ::
- (mk_env_entry wvar2 (fun i j => (ranm2 i j) / (Fsqrt (FfromZ (Z.of_nat dim2))))) ::
- (mk_env_entry bvar1 (ConstVector dim2 0)) ::
- (mk_env_entry bvar2 (ConstVector dim3 0)) :: nil.
-
- Definition mkNN_wisconsin (nsamp:nat) (ivar wvar bvar f1v f2v : SubVar) : (UnitDefinedFunction (DTMatrix 1 nsamp) * list var_type) :=
- let f1_activ := Max tt (Var (f1v,DTfloat) tt) (Number tt 0) in
- let f2_activ := Sigmoid f2v in
- mkNN2_Var_Mat_bias 9 15 1 nsamp ivar wvar bvar f1_activ f2_activ f1v f2v.
-
- Program Definition eval_wisconsin_batch (nsamp : nat)
- (σ:df_env)
- (normaldata: Matrix float nsamp 10) : option float :=
- let ivar := (Name "i") in
- let flnnv := (Name "NNv") in
- let outnnv := (Name "outnnv") in
- let '(NN, dvars) := mkNN_wisconsin nsamp ivar (Name "w") (Name "b") (Name "f1v") (Name "f2v") in
- @EvalNNinstancebatch 9 nsamp 1 ivar (CrossEntropy flnnv outnnv) flnnv outnnv
- NN σ (splitLastCol normaldata _).
- Next Obligation.
- lia.
- Qed.
-
-Program Definition wisconsin_instance_batch (nsamp : nat)
- (σ:df_env)
- (normaldata: Matrix float nsamp 10)
- : (df_env * (UnitDefinedFunction DTfloat)) * (list var_type) :=
- let ivar := (Name "i") in
- let flnnv := (Name "NNv") in
- let outnnv := (Name "outnnv") in
- let '(NN, dvars) := mkNN_wisconsin nsamp ivar (Name "w") (Name "b") (Name "f1v") (Name "f2v") in
- (@NNinstancebatch 9 nsamp 1 ivar (CrossEntropy flnnv outnnv) flnnv outnnv
- NN
- σ
- (splitLastCol normaldata _), dvars).
-Next Obligation.
- lia.
-Qed.
-
-CoFixpoint zeronoise : Stream float := Cons 0 zeronoise.
-
-(*
-https://towardsdatascience.com/predict-malignancy-in-breast-cancer-tumors-with-your-own-neural-network-and-the-wisconsin-dataset-76271a05e941
-*)
-
-Definition wisconsin_test (nsamp count : nat)
- (σ:df_env)
- (normaldata: Matrix float nsamp 10): list float :=
- let '(nninst, dvars) := wisconsin_instance_batch nsamp σ normaldata in
- let lr := fun _ => 1 / (FfromZ 100) in
- let onenv := fst (optimize_steps_tree_backprop 0 count (snd nninst) (fst nninst) lr
- zeronoise dvars) in
- match onenv with
- | Some nenv => match df_eval nenv (snd nninst) with
- | Some val => val :: nil
- | _ => nil
- end
- | _ => nil
- end.
-
-Definition wisconsin_test_env (nsamp count : nat)
- (σ:df_env)
- (normaldata: Matrix float nsamp 10): df_env :=
- let '(nninst, dvars) := wisconsin_instance_batch nsamp σ normaldata in
- let lr := fun _ => 1 / (FfromZ 100) in
- let onenv := fst (optimize_steps_tree_backprop 0 count (snd nninst) (fst nninst) lr
- zeronoise dvars) in
- match onenv with
- | Some nenv => nenv
- | _ => nil
- end.
-
-
-Example xvar:var_type := (Name "x", DTfloat).
-Example xfun:UnitDefinedFunction DTfloat := Var xvar tt.
-Example tquad:UnitDefinedFunction DTfloat := Times tt xfun xfun.
-Example quad:UnitDefinedFunction DTfloat := Minus tt (Times tt xfun xfun) (Number tt 1).
-Example squad:UnitDefinedFunction DTfloat := Minus tt (Square tt xfun) (Number tt 1).
-Example env : df_env := (mk_env_entry xvar (FfromZ 5))::nil.
-
-
-Example gradenv := match df_eval_backprop_deriv env quad (gradenv_init (xvar :: nil)) 1 with
- | Some gradenv => gradenv
- | _ => nil
- end.
-
-Example gradenv_tree :=
- match df_eval_tree env quad with
- | Some df_tree =>
- match df_eval_tree_backprop_deriv nil df_tree (gradenv_init (xvar :: nil)) 1 with
- | Some gradenv => gradenv
- | _ => nil
- end
- | _ => nil
- end.
-
-Example wisconsin_gradenv (nsamp : nat)
- (σ:df_env)
- (normaldata: Matrix float nsamp 10) : df_env :=
- let '((env,nn),dvars) := wisconsin_instance_batch nsamp σ normaldata in
- match df_eval_backprop_deriv env nn (gradenv_init dvars) 1 with
- | Some gradenv => gradenv
- | _ => nil
- end.
-
-Example wisconsin_gradenv_tree (nsamp : nat)
- (σ:df_env)
- (normaldata: Matrix float nsamp 10) : df_env :=
- let '((env,nn),dvars) := wisconsin_instance_batch nsamp σ normaldata in
- match df_eval_tree env nn with
- | Some df_tree =>
- match df_eval_tree_backprop_deriv nil df_tree (gradenv_init dvars) 1 with
- | Some gradenv => gradenv
- | _ => nil
- end
- | _ => nil
- end.
-
- Definition mkperceptron (n1 n2 : nat) (ivar wvar bvar : SubVar) (f_activ : UnitDefinedFunction DTfloat) (f_activ_var : SubVar) : (UnitDefinedFunction (DTVector n2) * list var_type) :=
- let mat1 := (wvar, DTMatrix n2 n1) in
- let b1 := (bvar, DTVector n2) in
- let ivec := Var (ivar, DTVector n1) tt in
- let N1 := mkNN_bias_step n1 n2 ivec (Var mat1 tt) (Var b1 tt) f_activ_var f_activ in
- (N1, mat1 :: b1 :: nil).
-
- Definition mkNN_test1 :=
- let ivar := Name "i" in
- let wvar := Name "w" in
- let bvar := Name "b" in
- let f1v := Name "f1" in
- let f1_activ := Max tt (Var (f1v,DTfloat) tt) (Number tt 0) in
- let '(nn, dvars) := mkperceptron 2 2 ivar wvar bvar f1_activ f1v in
- let winit := mk_env_entry (wvar, DTMatrix 2 2) (ConstMatrix 2 2 1) in
- let binit := mk_env_entry (bvar, DTVector 2) (ConstVector 2 1) in
- let datain := ConstVector 2 1 in
- let env := (mk_env_entry (ivar, DTVector 2) (ConstVector 2 1)):: winit :: binit :: nil in
- (env, nn).
-
- Definition mkNN_test :=
- let ivar := Name "i" in
- let wvar := Name "w" in
- let bvar := Name "b" in
- let f1v := Name "f1" in
-(* let f1_activ := Max tt (Var (f1v,DTfloat) tt) (Number tt 0) in *)
- let f1_activ := Sigmoid f1v in
- let '(nn, dvars) := mkperceptron 2 2 ivar wvar bvar f1_activ f1v in
- let datain := ConstVector 2 1 in
- let dataout := ConstVector 2 0 in
- let loss_nnvar := Name "lv" in
- let loss_ovar := Name "ov" in
- let floss := CrossEntropy loss_nnvar loss_ovar in
- let winit := mk_env_entry (wvar, DTMatrix 2 2) (ConstMatrix 2 2 (FfromZ 2)) in
- let binit := mk_env_entry (bvar, DTVector 2) (ConstVector 2 (FfromZ 3)) in
- let env := winit :: binit :: nil in
- (NNinstance1samp ivar floss loss_nnvar loss_ovar nn env (datain, dataout), dvars).
-
-Definition NN_test (count : nat) : list float :=
- let '(nninst, dvars) := mkNN_test in
- let onenv := fst (optimize_steps_tree_backprop 0 count (snd nninst) (fst nninst) harmonic_lr
- zeronoise dvars) in
- match onenv with
- | Some nenv => match df_eval nenv (snd nninst) with
- | Some val => val :: nil
- | _ => nil
- end
- | _ => nil
- end.
-
-Example NN_test_gradenv : df_env :=
- let '((env, nn), dvars) := mkNN_test in
- match df_eval_backprop_deriv env nn (gradenv_init dvars) 1 with
- | Some gradenv => gradenv
- | _ => nil
- end.
-
-Example NN_test_gradenv_tree :df_env :=
- let '((env, nn), dvars) := mkNN_test in
- match df_eval_tree env nn with
- | Some df_tree =>
- match df_eval_tree_backprop_deriv nil df_tree (gradenv_init dvars) 1 with
- | Some gradenv => gradenv
- | _ => nil
- end
- | _ => nil
- end.
-
-
-Definition NN_test_env (count : nat) : df_env :=
- let '(nninst, dvars) := mkNN_test in
- let onenv := fst (optimize_steps_tree_backprop 0 count (snd nninst) (fst nninst)
- harmonic_lr zeronoise dvars) in
- match onenv with
- | Some nenv => nenv
- | _ => nil
- end.
-
-Definition NN_test_NN :=
- let '((env,nn), dvars) := mkNN_test in
- nn.
-
-Definition NN_test_val : list float :=
- let '((env,nn), dvars) := mkNN_test in
- match df_eval env nn with
- | Some val => val :: nil
- | _ => nil
- end.
-
-Definition test_update_val_gradenv
- := update_val_gradenv gradenv xvar 1 (FfromZ 5) 0.
-
-Definition test_update : df_env :=
- (fst (update_entry (mk_env_entry xvar (FfromZ 5)) gradenv 1 zeronoise))::nil.
-
-Definition test_optimize_step_backprop
- (step : nat) (df : UnitDefinedFunction DTfloat) (σ:df_env)
- (noise_st : Stream float) (dvars : list var_type) : df_env :=
- match fst (optimize_step_backprop step df σ harmonic_lr noise_st dvars) with
- | Some env => env
- | _ => nil
- end.
-
-Definition test_optimize_step_tree_backprop
- (step : nat) (df : UnitDefinedFunction DTfloat) (σ:df_env)
- (noise_st : Stream float) (dvars : list var_type) : df_env :=
- match fst (optimize_step_tree_backprop step df σ harmonic_lr noise_st dvars) with
- | Some env => env
- | _ => nil
- end.
-
-Example testopt := test_optimize_step_backprop 0 quad env zeronoise (xvar :: nil).
-Example testreeopt := test_optimize_step_tree_backprop 0 quad env zeronoise (xvar :: nil).
-
-Example opt := fst (optimize_steps 0 2 quad env ((fst xvar) :: nil) zeronoise).
-Example opt2 := fst (optimize_steps_tree_backprop 0 2 quad env harmonic_lr zeronoise (xvar :: nil)).
-
-Example val := 1+1.
-
-End GenNN.
-
-(*
-Local Instance floatish_interval : floatish := floatish_interval_gen 53.
-Compute df_eval NN_test_env NN_test_NN.
-*)
diff --git a/coq/NeuralNetworks/NN.v b/coq/NeuralNetworks/NN.v
deleted file mode 100644
index 5ccf88c5..00000000
--- a/coq/NeuralNetworks/NN.v
+++ /dev/null
@@ -1,59 +0,0 @@
-Require Import Reals.Rbase.
-Require Import Reals.Rfunctions.
-Require Import Arith.
-Require Import String.
-Require Import Vector.
-
-Require Import AxiomaticNormedRealVectorSpace.
-
-Module NN.
-
-Section SeriesDivergence.
-
-Definition converges (s: nat -> R) :=
- exists sum:R, infinite_sum s sum.
-
-Definition diverges (s: nat -> R) : Prop :=
- ~(converges s).
-
-Definition diverges_right (s: nat -> R) : Prop :=
- forall m: R,
- exists N: nat,
- forall n: nat,
- n >= N -> Rgt (s n) m.
-
-Definition diverges_left (s: nat -> R) : Prop :=
- forall m: R,
- exists N: nat,
- forall n:nat,
- n >= N -> Rlt (s n) m.
-
-End SeriesDivergence.
-
-Section AssumptionC.
-
-Local Open Scope R_scope.
-Definition Assumption_C_1 (ak : nat -> R) : Prop :=
- let
- ak_squared (n : nat) := (ak n)^2
- in
- (forall n, ak n >= 0) /\
- diverges_right ak /\
- converges ak_squared.
-
-Definition Assumption_C_2 (s: Set) (x: nat -> rvector s) : Prop :=
- exists M : R,
- forall k: nat, norm s (x k) < M.
-
-(* TODO *)
-Definition Assumption_C_3 (zeta: nat -> R) : Prop :=
- ZeroMeanBoundedVariance zeta.
-
-Definition Assumption_C (s: Set) (zeta: nat -> R) (alpha: nat -> R) (x : nat -> rvector s) : Prop :=
- Assumption_C_1 alpha /\ Assumption_C_2 s x /\ Assumption_C_3 zeta.
-
-Local Close Scope R_scope.
-
-End AssumptionC.
-
-End NN.
diff --git a/coq/NeuralNetworks/nDefinedFunctions.v b/coq/NeuralNetworks/nDefinedFunctions.v
deleted file mode 100644
index d5e3f8b6..00000000
--- a/coq/NeuralNetworks/nDefinedFunctions.v
+++ /dev/null
@@ -1,15740 +0,0 @@
-Require Import String.
-Require Import EquivDec.
-Require Import RelationClasses.
-Require Import List.
-Require Import Permutation.
-Require Import NPeano.
-Require Import Lra Lia.
-Require Reals.
-Require Import Eqdep_dec.
-
-Require Import Floatish.
-Require Import Utils.
-Require Import derivlemmas.
-Require VectorDef.
-Require Import nvector.
-
-Set Bullet Behavior "Strict Subproofs".
-
-Section DefinedFunctions.
-
- Context {floatish_impl:floatish}.
- Local Open Scope float.
-
-(* Declare Scope df_scope. *)
-
-(* in pytorch relu(f)' if f <=0 then 0 else f' *)
-(* in pytorch abs(f)' = f'*sign(f) *)
-(* max(a,b)' = if a<=b then b' else a' *)
-(* min(a,b)' = if a>=b then b' else a' *)
-(* x = Variable(torch.tensor(0.0), requires_grad=True) *)
-(* z = torch.min(x*x, x); z.backward(); print(x.grad) = 1 *)
-(* x.grad.data.zero_() between tests *)
-(* relu behaves like max(x, 0), not max(0,x), i.e. relu(x)' at 0 = 0 *)
-
-
- Section Definitions.
-
- Definition var := string.
-
- Inductive SubVar : Set :=
- | Name (s : string)
- | Sub (v : SubVar) (i : nat).
-
-
- Definition var_dec : forall v1 v2 : SubVar, {v1 = v2} + {v1 <> v2}.
- Proof.
- decide equality.
- apply string_dec.
- apply Nat.eq_dec.
- Defined.
-
- Global Instance var_eqdec : EqDec SubVar eq.
- Proof.
- intros x y.
- apply var_dec.
- Defined.
-
- (* A subset of defined functions *)
-
-
- Inductive definition_function_types
- := DTfloat
- | DTVector (n:nat)
- | DTMatrix (m n:nat).
-
- Definition definition_function_types_interp (dft:definition_function_types) : Type
- := match dft with
- | DTfloat => float
- | DTVector n => vector float n
- | DTMatrix m n => matrix float m n
- end.
-
- Inductive data_type : definition_function_types -> Type
- := DataFloat : data_type DTfloat
- | DataVector n (v:vector float n) : data_type (DTVector n)
- | DataMatrix m n (mat:matrix float m n) : data_type (DTMatrix m n).
-
- Definition var_type := (SubVar * definition_function_types)%type.
-
- Definition definition_function_types_dec : forall v1 v2 : definition_function_types, {v1 = v2} + {v1 <> v2}.
- Proof.
- decide equality; apply Nat.eq_dec.
- Defined.
-
- Definition vart_dec : forall v1 v2 : var_type, {v1 = v2} + {v1 <> v2}.
- Proof.
- decide equality.
- - apply definition_function_types_dec.
- - apply var_dec.
- Defined.
-
- Global Instance vart_eqdec : EqDec var_type eq.
- Proof.
- intros ??.
- apply vart_dec.
- Defined.
-
- Lemma var_type_UIP_refl {x:var_type} (e:x=x) : e = eq_refl x.
- Proof.
- apply (UIP_dec vart_dec).
- Qed.
-
- Lemma definition_function_types_UIP_refl {x:definition_function_types} (e:x=x) : e = eq_refl x.
- Proof.
- apply (UIP_dec definition_function_types_dec).
- Qed.
-
- Definition env_entry_type := {v:var_type & definition_function_types_interp (snd v)}.
- Definition df_env := list env_entry_type.
-
- Definition mk_env_entry v e : env_entry_type
- := let P := fun xv => definition_function_types_interp (snd xv) in
- existT P v e.
-
- Definition UnitAnn: definition_function_types->Type := fun _ => unit.
- Definition EvalAnn: definition_function_types->Type := definition_function_types_interp.
-
- Inductive DefinedFunction {Ann:definition_function_types->Type} : definition_function_types -> Type :=
- | Number (ann:Ann DTfloat) (x : float) : DefinedFunction DTfloat
- | Constant {t:definition_function_types} (ann:Ann t) (x : definition_function_types_interp t) : DefinedFunction t
-
- | DVector {n} (ann:Ann (DTVector n)) (x : vector (DefinedFunction DTfloat) n) : DefinedFunction (DTVector n)
-
- | DMatrix {n m} (ann:Ann (DTMatrix n m)) (x : matrix (DefinedFunction DTfloat) n m) : DefinedFunction (DTMatrix n m)
-
-
- | Var (v : var_type) (ann: Ann (snd v)) : DefinedFunction (snd v)
- | Plus (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Minus (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Times (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Divide (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Square (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Exp (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Log (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Abs (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Sign (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | PSign (ann:Ann DTfloat) (e : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | Max (ann:Ann DTfloat) (l r : DefinedFunction DTfloat) : DefinedFunction DTfloat
- | VectorDot {n} (ann:Ann DTfloat) (l r: DefinedFunction (DTVector n)) : DefinedFunction DTfloat
- | VectorSum {n} (ann:Ann DTfloat) (v: DefinedFunction (DTVector n)) : DefinedFunction DTfloat
- | MatrixSum {m n} (ann:Ann DTfloat) (v: DefinedFunction (DTMatrix m n)) : DefinedFunction DTfloat
- | VectorElem {n} (ann:Ann DTfloat) (l:DefinedFunction (DTVector n)) (i:{x:nat|x Prop)
- (f : forall (ann : UnitAnn DTfloat) (x : float),
- P DTfloat (Number ann x))
- (f0 : forall (t : definition_function_types)
- (ann : UnitAnn t) (x : definition_function_types_interp t), P t (Constant ann x))
- (f1 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (x : vector (DefinedFunction UnitAnn DTfloat) n)
- (f: vforall (P DTfloat) x),
- P (DTVector n) (DVector ann x))
- (f2 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (x : matrix (DefinedFunction UnitAnn DTfloat) n m)
- (f: mforall (P DTfloat) x),
- P (DTMatrix n m) (DMatrix ann x))
- (f3 : forall (v : var_type) (ann : UnitAnn (snd v)),
- P (snd v) (Var v ann))
- (f4 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Plus ann l r))
- (f5 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Minus ann l r))
- (f6 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Times ann l r))
- (f7 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Divide ann l r))
- (f8 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Square ann e))
- (f9 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Exp ann e))
- (f10 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Log ann e))
- (f11 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Abs ann e))
- (f12 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Sign ann e))
- (f13 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (PSign ann e))
- (f14 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Max ann l r))
- (f15 : forall (n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P DTfloat (VectorDot ann l r))
- (f16 : forall (n : nat) (ann : UnitAnn DTfloat)
- (v : DefinedFunction UnitAnn (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v))
- (f17 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (v : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v))
- (f18 : forall (n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i))
- (f19 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}),
- P DTfloat (MatrixElem ann l i j))
- (f20 : forall (m n : nat) (ann : UnitAnn (DTVector m))
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r))
- (f21 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n))
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction UnitAnn (DTVector m),
- P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r))
- (f22 : forall (m p n : nat) (ann : UnitAnn (DTMatrix m n))
- (l : DefinedFunction UnitAnn (DTMatrix m p)),
- P (DTMatrix m p) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix p n),
- P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r))
- (f23 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r))
- (f24 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r))
- (f25 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (l : DefinedFunction UnitAnn (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r))
- (f26 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (l : DefinedFunction UnitAnn (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r))
- (f27 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (x : DefinedFunction UnitAnn DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l))
- (f28 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (x : DefinedFunction UnitAnn DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l))
- (f29 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l))
- (f30 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTMatrix m n),
- P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l))
- (f31 : forall (n : nat) (ann : UnitAnn DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> forall r : vector float n, P DTfloat (VLossfun ann v1 v2 s l r))
- (f32 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTMatrix m n),
- P (DTMatrix m n) l ->
- forall r : matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r))
- :=
-fix
-F (d : definition_function_types)
- (d0 : DefinedFunction UnitAnn d) {struct d0} : P d d0 :=
- match d0 as d2 in (DefinedFunction _ d1) return (P d1 d2) with
- | Number ann x => f ann x
- | @Constant _ t ann x => f0 t ann x
- | @DVector _ n ann x =>
- f1 n ann x
- ((fix F1 n (x:vector (DefinedFunction UnitAnn DTfloat) n) : vforall (P DTfloat) x :=
- match x with
- | Vector.nil => Vector.Forall_nil (P DTfloat)
- | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h) (F1 _ tl)
- end) n x)
- | @DMatrix _ n m ann x =>
- f2 n m ann x
- ((fix F2 n m (x:matrix (DefinedFunction UnitAnn DTfloat) n m) : mforall (P DTfloat) x :=
- match x with
- | Vector.nil => Vector.Forall_nil (vforall (P DTfloat))
-
- | Vector.cons h _ tl =>
- Vector.Forall_cons _ _ _
- ((fix F1 m (x:vector (DefinedFunction UnitAnn DTfloat) m) : vforall (P DTfloat) x :=
- match x with
- | Vector.nil => Vector.Forall_nil (P DTfloat)
- | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h) (F1 _ tl)
- end) m h)
- (F2 _ _ tl)
- end) n m x)
- | Var v ann => f3 v ann
- | Plus ann l r => f4 ann l (F DTfloat l) r (F DTfloat r)
- | Minus ann l r => f5 ann l (F DTfloat l) r (F DTfloat r)
- | Times ann l r => f6 ann l (F DTfloat l) r (F DTfloat r)
- | Divide ann l r => f7 ann l (F DTfloat l) r (F DTfloat r)
- | Square ann e => f8 ann e (F DTfloat e)
- | Exp ann e => f9 ann e (F DTfloat e)
- | Log ann e => f10 ann e (F DTfloat e)
- | Abs ann e => f11 ann e (F DTfloat e)
- | Sign ann e => f12 ann e (F DTfloat e)
- | PSign ann e => f13 ann e (F DTfloat e)
- | Max ann l r => f14 ann l (F DTfloat l) r (F DTfloat r)
- | @VectorDot _ n ann l r => f15 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @VectorSum _ n ann v => f16 n ann v (F (DTVector n) v)
- | @MatrixSum _ m n ann v => f17 m n ann v (F (DTMatrix m n) v)
- | @VectorElem _ n ann l i => f18 n ann l (F (DTVector n) l) i
- | @MatrixElem _ m n ann l i j => f19 m n ann l (F (DTMatrix m n) l) i j
- | @MatrixVectorMult _ m n ann l r =>
- f20 m n ann l (F (DTMatrix m n) l) r (F (DTVector n) r)
- | @MatrixVectorAdd _ m n ann l r =>
- f21 m n ann l (F (DTMatrix m n) l) r (F (DTVector m) r)
- | @MatrixMult _ m p n ann l r =>
- f22 m p n ann l (F (DTMatrix m p) l) r (F (DTMatrix p n) r)
- | @VectorPlus _ n ann l r => f23 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @VectorMinus _ n ann l r => f24 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @MatrixPlus _ n m ann l r => f25 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r)
- | @MatrixMinus _ n m ann l r =>
- f26 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r)
- | @VectorScalMult _ n ann x l => f27 n ann x (F DTfloat x) l (F (DTVector n) l)
- | @MatrixScalMult _ n m ann x l => f28 n m ann x (F DTfloat x) l (F (DTMatrix n m) l)
- | @VectorApply _ n ann v s l => f29 n ann v s (F DTfloat s) l (F (DTVector n) l)
- | @MatrixApply _ m n ann v s l =>
- f30 m n ann v s (F DTfloat s) l (F (DTMatrix m n) l)
- | @VLossfun _ n ann v1 v2 s l r =>
- f31 n ann v1 v2 s (F DTfloat s) l (F (DTVector n) l) r
- | @MLossfun _ m n ann v1 v2 s l r =>
- f32 m n ann v1 v2 s (F DTfloat s) l (F (DTMatrix m n) l) r
- end.
-
-Definition DefinedFunction_ind_simpl {Ann}
- (P : forall (d : definition_function_types), DefinedFunction Ann d -> Prop)
- (f : forall (ann : Ann DTfloat) (x : float),
- P DTfloat (Number ann x))
- (f0 : forall (t : definition_function_types)
- (ann : Ann t) (x : definition_function_types_interp t), P t (Constant ann x))
- (f1 : forall (n : nat) (ann : Ann (DTVector n))
- (x : vector (DefinedFunction Ann DTfloat) n)
- (f: vforall (P DTfloat) x),
- P (DTVector n) (DVector ann x))
- (f2 : forall (n m : nat) (ann : Ann (DTMatrix n m))
- (x : matrix (DefinedFunction Ann DTfloat) n m)
- (f: mforall (P DTfloat) x),
- P (DTMatrix n m) (DMatrix ann x))
- (f3 : forall (v : var_type) (ann : Ann (snd v)),
- P (snd v) (Var v ann))
- (f4 : forall (ann : Ann DTfloat)
- (l : DefinedFunction Ann DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Plus ann l r))
- (f5 : forall (ann : Ann DTfloat)
- (l : DefinedFunction Ann DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Minus ann l r))
- (f6 : forall (ann : Ann DTfloat)
- (l : DefinedFunction Ann DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Times ann l r))
- (f7 : forall (ann : Ann DTfloat)
- (l : DefinedFunction Ann DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Divide ann l r))
- (f8 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Square ann e))
- (f9 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Exp ann e))
- (f10 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Log ann e))
- (f11 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Abs ann e))
- (f12 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (Sign ann e))
- (f13 : forall (ann : Ann DTfloat)
- (e : DefinedFunction Ann DTfloat), P DTfloat e -> P DTfloat (PSign ann e))
- (f14 : forall (ann : Ann DTfloat)
- (l : DefinedFunction Ann DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction Ann DTfloat, P DTfloat r -> P DTfloat (Max ann l r))
- (f15 : forall (n : nat) (ann : Ann DTfloat)
- (l : DefinedFunction Ann (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction Ann (DTVector n),
- P (DTVector n) r -> P DTfloat (VectorDot ann l r))
- (f16 : forall (n : nat) (ann : Ann DTfloat)
- (v : DefinedFunction Ann (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v))
- (f17 : forall (m n : nat) (ann : Ann DTfloat)
- (v : DefinedFunction Ann (DTMatrix m n)),
- P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v))
- (f18 : forall (n : nat) (ann : Ann DTfloat)
- (l : DefinedFunction Ann (DTVector n)),
- P (DTVector n) l ->
- forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i))
- (f19 : forall (m n : nat) (ann : Ann DTfloat)
- (l : DefinedFunction Ann (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}),
- P DTfloat (MatrixElem ann l i j))
- (f20 : forall (m n : nat) (ann : Ann (DTVector m))
- (l : DefinedFunction Ann (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction Ann (DTVector n),
- P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r))
- (f21 : forall (m n : nat) (ann : Ann (DTMatrix m n))
- (l : DefinedFunction Ann (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction Ann (DTVector m),
- P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r))
- (f22 : forall (m p n : nat) (ann : Ann (DTMatrix m n))
- (l : DefinedFunction Ann (DTMatrix m p)),
- P (DTMatrix m p) l ->
- forall r : DefinedFunction Ann (DTMatrix p n),
- P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r))
- (f23 : forall (n : nat) (ann : Ann (DTVector n))
- (l : DefinedFunction Ann (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction Ann (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r))
- (f24 : forall (n : nat) (ann : Ann (DTVector n))
- (l : DefinedFunction Ann (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction Ann (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r))
- (f25 : forall (n m : nat) (ann : Ann (DTMatrix n m))
- (l : DefinedFunction Ann (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction Ann (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r))
- (f26 : forall (n m : nat) (ann : Ann (DTMatrix n m))
- (l : DefinedFunction Ann (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction Ann (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r))
- (f27 : forall (n : nat) (ann : Ann (DTVector n))
- (x : DefinedFunction Ann DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction Ann (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l))
- (f28 : forall (n m : nat) (ann : Ann (DTMatrix n m))
- (x : DefinedFunction Ann DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction Ann (DTMatrix n m),
- P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l))
- (f29 : forall (n : nat) (ann : Ann (DTVector n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- forall l : DefinedFunction Ann (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l))
- (f30 : forall (m n : nat) (ann : Ann (DTMatrix m n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- forall l : DefinedFunction Ann (DTMatrix m n),
- P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l))
- (f31 : forall (n : nat) (ann : Ann DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- forall l : DefinedFunction Ann (DTVector n),
- P (DTVector n) l -> forall r : vector float n, P DTfloat (VLossfun ann v1 v2 s l r))
- (f32 : forall (m n : nat) (ann : Ann DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- forall l : DefinedFunction Ann (DTMatrix m n),
- P (DTMatrix m n) l ->
- forall r : matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r))
- :=
-fix
-F (d : definition_function_types)
- (d0 : DefinedFunction Ann d) {struct d0} : P d d0 :=
- match d0 as d2 in (DefinedFunction _ d1) return (P d1 d2) with
- | Number ann x => f ann x
- | @Constant _ t ann x => f0 t ann x
- | @DVector _ n ann x =>
- f1 n ann x
- ((fix F1 n (x:vector (DefinedFunction Ann DTfloat) n) : vforall (P DTfloat) x :=
- match x with
- | Vector.nil => Vector.Forall_nil (P DTfloat)
- | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h) (F1 _ tl)
- end) n x)
- | @DMatrix _ n m ann x =>
- f2 n m ann x
- ((fix F2 n m (x:matrix (DefinedFunction Ann DTfloat) n m) : mforall (P DTfloat) x :=
- match x with
- | Vector.nil => Vector.Forall_nil (vforall (P DTfloat))
-
- | Vector.cons h _ tl =>
- Vector.Forall_cons _ _ _
- ((fix F1 m (x:vector (DefinedFunction Ann DTfloat) m) : vforall (P DTfloat) x :=
- match x with
- | Vector.nil => Vector.Forall_nil (P DTfloat)
- | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h) (F1 _ tl)
- end) m h)
- (F2 _ _ tl)
- end) n m x)
- | Var v ann => f3 v ann
- | Plus ann l r => f4 ann l (F DTfloat l) r (F DTfloat r)
- | Minus ann l r => f5 ann l (F DTfloat l) r (F DTfloat r)
- | Times ann l r => f6 ann l (F DTfloat l) r (F DTfloat r)
- | Divide ann l r => f7 ann l (F DTfloat l) r (F DTfloat r)
- | Square ann e => f8 ann e (F DTfloat e)
- | Exp ann e => f9 ann e (F DTfloat e)
- | Log ann e => f10 ann e (F DTfloat e)
- | Abs ann e => f11 ann e (F DTfloat e)
- | Sign ann e => f12 ann e (F DTfloat e)
- | PSign ann e => f13 ann e (F DTfloat e)
- | Max ann l r => f14 ann l (F DTfloat l) r (F DTfloat r)
- | @VectorDot _ n ann l r => f15 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @VectorSum _ n ann v => f16 n ann v (F (DTVector n) v)
- | @MatrixSum _ m n ann v => f17 m n ann v (F (DTMatrix m n) v)
- | @VectorElem _ n ann l i => f18 n ann l (F (DTVector n) l) i
- | @MatrixElem _ m n ann l i j => f19 m n ann l (F (DTMatrix m n) l) i j
- | @MatrixVectorMult _ m n ann l r =>
- f20 m n ann l (F (DTMatrix m n) l) r (F (DTVector n) r)
- | @MatrixVectorAdd _ m n ann l r =>
- f21 m n ann l (F (DTMatrix m n) l) r (F (DTVector m) r)
- | @MatrixMult _ m p n ann l r =>
- f22 m p n ann l (F (DTMatrix m p) l) r (F (DTMatrix p n) r)
- | @VectorPlus _ n ann l r => f23 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @VectorMinus _ n ann l r => f24 n ann l (F (DTVector n) l) r (F (DTVector n) r)
- | @MatrixPlus _ n m ann l r => f25 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r)
- | @MatrixMinus _ n m ann l r =>
- f26 n m ann l (F (DTMatrix n m) l) r (F (DTMatrix n m) r)
- | @VectorScalMult _ n ann x l => f27 n ann x (F DTfloat x) l (F (DTVector n) l)
- | @MatrixScalMult _ n m ann x l => f28 n m ann x (F DTfloat x) l (F (DTMatrix n m) l)
- | @VectorApply _ n ann v s l => f29 n ann v s l (F (DTVector n) l)
- | @MatrixApply _ m n ann v s l =>
- f30 m n ann v s l (F (DTMatrix m n) l)
- | @VLossfun _ n ann v1 v2 s l r =>
- f31 n ann v1 v2 s l (F (DTVector n) l) r
- | @MLossfun _ m n ann v1 v2 s l r =>
- f32 m n ann v1 v2 s l (F (DTMatrix m n) l) r
- end.
-
- Definition get_annotation {Ann T} (df:DefinedFunction Ann T) : Ann T
- := match df with
- | Number ann _ => ann
- | Constant _ ann _ => ann
- | DVector _ ann _ => ann
- | DMatrix _ _ ann _ => ann
- | Var _ ann => ann
- | Plus ann _ _ => ann
- | Minus ann _ _ => ann
- | Times ann _ _ => ann
- | Divide ann _ _ => ann
- | Square ann _ => ann
- | Exp ann _ => ann
- | Log ann _ => ann
- | Abs ann _ => ann
- | Sign ann _ => ann
- | PSign ann _ => ann
- | Max ann _ _ => ann
- | VectorDot _ ann _ _ => ann
- | VectorSum _ ann _ => ann
- | MatrixSum _ _ ann _ => ann
- | VectorElem _ ann _ _ => ann
- | MatrixElem _ _ ann _ _ _ => ann
- | MatrixVectorMult _ _ ann _ _ => ann
- | MatrixVectorAdd _ _ ann _ _ => ann
- | MatrixMult _ _ _ ann _ _ => ann
- | VectorPlus _ ann _ _ => ann
- | VectorMinus _ ann _ _ => ann
- | MatrixPlus _ _ ann _ _ => ann
- | MatrixMinus _ _ ann _ _ => ann
- | VectorScalMult _ ann _ _ => ann
- | MatrixScalMult _ _ ann _ _ => ann
- | VectorApply _ ann _ _ _ => ann
- | MatrixApply _ _ ann _ _ _ => ann
- | VLossfun _ ann _ _ _ _ _ => ann
- | MLossfun _ _ ann _ _ _ _ _ => ann
- end.
-
- Definition dft_eq_dec :
- forall (t1 t2 : definition_function_types), {t1 = t2} + {t1 <> t2}.
- Proof.
- decide equality.
- decide equality.
- apply Nat.eq_dec.
- apply Nat.eq_dec.
- Defined.
-
- Global Instance dft_eqdec : EqDec definition_function_types eq.
- Proof.
- intros ??.
- apply dft_eq_dec.
- Defined.
-
- End Definitions.
-
- Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "DVector"%string
- | Case_aux c "DMatrix"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string
- | Case_aux c "VectorDot"%string
- | Case_aux c "VectorSum"%string
- | Case_aux c "MatrixSum"%string
- | Case_aux c "VectorElem"%string
- | Case_aux c "MatrixElem"%string
- | Case_aux c "MatrixVectorMult"%string
- | Case_aux c "MatrixVectorAdd"%string
- | Case_aux c "MatrixMult"%string
- | Case_aux c "VectorPlus"%string
- | Case_aux c "VectorMinus"%string
- | Case_aux c "MatrixPlus"%string
- | Case_aux c "MatrixMinus"%string
- | Case_aux c "VectorScalMult"%string
- | Case_aux c "MatrixScalMult"%string
- | Case_aux c "VectorApply"%string
- | Case_aux c "MatrixApply"%string
- | Case_aux c "VLossfun"%string
- | Case_aux c "MLossfun"%string].
-
-
- Ltac refl_simpler :=
- repeat
- match goal with
- | [H: @eq var_type _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H)
- | [H: @equiv var_type _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H)
- | [H: @eq definition_function_types _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H)
- | [H: @equiv definition_function_types _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H)
- end.
-
-
- Definition df_plus (df1 df2 : DefinedFunction UnitAnn DTfloat) : DefinedFunction UnitAnn DTfloat :=
- Plus tt df1 df2.
-
- Definition df_times (df1 df2 : DefinedFunction UnitAnn DTfloat) : DefinedFunction UnitAnn DTfloat :=
- Times tt df1 df2.
-
- Definition defined_sum {m} (v:vector (DefinedFunction UnitAnn DTfloat) m) : DefinedFunction UnitAnn DTfloat
- := Vector.fold_right (fun a b => Plus tt a b) v (Number tt 0).
-
- Definition vsum {m:nat} (v:vector float m) : float
- := Vector.fold_right Fplus v 0.
-
- Definition msum {m n:nat} (v:matrix float m n) : float :=
- vsum (vmap vsum v).
-
- Definition transpose {A} {n m:nat} (mat:matrix A n m) :=
- build_matrix (fun i j => mnth mat j i).
-
-(* defined in nvector
- Definition matrix_vector_mult {n m} (l : Matrix float n m)(r : vector float m) : vector float n :=
- fun i => vsum (fun j => (l i j) * (r j)).
-
- Definition matrix_vector_add {n m} (l : Matrix float n m) (r : vector float n) : Matrix float n m := fun i j => (l i j) + (r i).
-
- Definition matrix_mult {n m p} (l : Matrix float n m)(r : Matrix float m p) : Matrix float n p :=
- fun i k => vsum (fun j => (l i j) * (r j k)).
-*)
-
- Section deriv.
-
-
- Section subst.
-
- Definition substvar {Ann} (v vv:var_type) (e':DefinedFunction Ann (snd v)) (e:DefinedFunction Ann (snd vv)) : (DefinedFunction Ann (snd vv)) :=
- match snd v == snd vv with
- | left pf => eq_rect _ (fun t => DefinedFunction Ann t) e' _ pf
- | right pf => e
- end.
-
- Fixpoint df_subst {T Ann} (df: DefinedFunction Ann T) (v:var_type) (e':DefinedFunction UnitAnn (snd v)) {struct df} :=
- match df with
- | Number _ x => Number tt x
- | Constant t _ x => Constant tt x
- | DVector n _ df => DVector tt (vmap (fun x => df_subst x v e') df)
- | DMatrix n m _ df => DMatrix tt (mmap (fun x => df_subst x v e') df)
- | Var vvar _ => substvar v vvar e' (Var vvar tt)
- | Plus _ l r => Plus tt (df_subst l v e') (df_subst r v e')
- | Times _ l r => Times tt (df_subst l v e') (df_subst r v e')
- | Minus _ l r => Minus tt (df_subst l v e') (df_subst r v e')
- | Divide _ l r => Divide tt (df_subst l v e') (df_subst r v e')
- | Square _ e => Square tt (df_subst e v e')
- | Exp _ e => Exp tt (df_subst e v e')
- | Log _ e => Log tt (df_subst e v e')
- | Abs _ e => Abs tt (df_subst e v e')
- | Sign _ e => Sign tt (df_subst e v e')
- | PSign _ e => PSign tt (df_subst e v e')
- | Max _ l r => Max tt (df_subst l v e') (df_subst r v e')
- | VectorElem n _ l i => VectorElem tt (df_subst l v e') i
- | MatrixElem m n _ l i j => MatrixElem tt (df_subst l v e') i j
- | VectorDot n _ l r =>
- VectorDot tt (df_subst l v e') (df_subst r v e')
- | VectorSum n _ e =>
- VectorSum tt (df_subst e v e')
- | MatrixSum n m _ e =>
- MatrixSum tt (df_subst e v e')
- | VectorScalMult n _ x r =>
- VectorScalMult tt (df_subst x v e') (df_subst r v e')
- | MatrixScalMult n m _ x r =>
- MatrixScalMult tt (df_subst x v e') (df_subst r v e')
- | MatrixVectorMult n m _ l r =>
- MatrixVectorMult tt (df_subst l v e') (df_subst r v e')
- | MatrixVectorAdd n m _ l r =>
- MatrixVectorAdd tt (df_subst l v e') (df_subst r v e')
- | MatrixMult n m p _ l r =>
- MatrixMult tt (df_subst l v e') (df_subst r v e')
- | VectorPlus n _ l r =>
- VectorPlus tt (df_subst l v e') (df_subst r v e')
- | VectorMinus n _ l r =>
- VectorMinus tt (df_subst l v e') (df_subst r v e')
- | MatrixPlus n m _ l r =>
- MatrixPlus tt (df_subst l v e') (df_subst r v e')
- | MatrixMinus n m _ l r =>
- MatrixMinus tt (df_subst l v e') (df_subst r v e')
- | VectorApply n _ x s l =>
- VectorApply tt x s (df_subst l v e')
- | MatrixApply n m _ x s l =>
- MatrixApply tt x s (df_subst l v e')
- | VLossfun n _ v1 v2 s l r =>
- VLossfun tt v1 v2 s (df_subst l v e') r
- | MLossfun n m _ v1 v2 s l r =>
- MLossfun tt v1 v2 s (df_subst l v e') r
- end.
-
- Definition df_substp {T Ann} :=
- fun e (ve':{v:var_type & DefinedFunction UnitAnn (snd v)}) =>
- @df_subst T Ann e (projT1 ve') (projT2 ve').
-
- Definition df_subst_list {T} (e:DefinedFunction UnitAnn T)
- (l:list {v:var_type & DefinedFunction UnitAnn (snd v)}) : DefinedFunction UnitAnn T
- := fold_left (@df_substp T UnitAnn) l e.
-
- End subst.
-
-
-(* restrict to scalar v? *)
- Fixpoint df_deriv {T} (df:DefinedFunction UnitAnn T) (v:var_type) {struct df} : DefinedFunction UnitAnn T
- := (match df with
- | Number _ _ => Number tt 0
- | Constant t _ x => Constant tt
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ df => DVector tt (vmap (fun x => df_deriv x v) df)
- | DMatrix n m _ df => DMatrix tt (mmap (fun x => df_deriv x v) df)
- | Var x _ => Constant tt
- match snd x as y return definition_function_types_interp y with
- | DTfloat => if x == v then 1 else 0
- | DTVector n => ConstVector n (if x == v then 1 else 0)
- | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0)
- end
- | Plus _ l r => Plus tt (df_deriv l v) (df_deriv r v)
- | Minus _ l r => Minus tt (df_deriv l v) (df_deriv r v)
- | Times _ l r => Plus tt (Times tt l (df_deriv r v))
- (Times tt (df_deriv l v) r)
- | Divide _ l r => Minus tt
- (Divide tt (df_deriv l v) r)
- (Divide tt (Times tt l (df_deriv r v))
- (Times tt r r))
- | Square _ e => Times tt
- (Times tt (Number tt 2) e) (df_deriv e v)
- | Exp _ e => Times tt (df_deriv e v) (Exp tt e)
- | Log _ e => Divide tt (df_deriv e v) e
- | Abs _ e => Times tt (df_deriv e v) (Sign tt e)
- | Sign _ e => Number tt 0
- | PSign _ e => Number tt 0
- | Max _ l r => Divide tt
- (Plus tt
- (Times tt (Minus tt
- (df_deriv r v)
- (df_deriv l v))
- (PSign tt (Minus tt r l)))
- (Plus tt (df_deriv r v) (df_deriv l v)))
- (Number tt 2)
- | VectorElem n _ l i => VectorElem tt (df_deriv l v) i
- | MatrixElem m n _ l i j => MatrixElem tt (df_deriv l v) i j
- | VectorDot n _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- Plus tt (VectorDot tt ll r) (VectorDot tt l rr)
- | VectorSum n _ l =>
- let ll := df_deriv l v in
- VectorSum tt ll
- | MatrixSum m n _ l =>
- let ll := df_deriv l v in
- MatrixSum tt ll
- | VectorScalMult n _ x r =>
- let xx := df_deriv x v in
- let rr := df_deriv r v in
- VectorPlus tt
- (VectorScalMult tt xx r)
- (VectorScalMult tt x rr)
- | MatrixScalMult n m _ x r =>
- let xx := df_deriv x v in
- let rr := df_deriv r v in
- MatrixPlus tt
- (MatrixScalMult tt xx r) (MatrixScalMult tt x rr)
- | MatrixVectorMult n m _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- VectorPlus tt (MatrixVectorMult tt ll r)
- (MatrixVectorMult tt l rr)
- | MatrixVectorAdd n m _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- MatrixVectorAdd tt ll rr
- | MatrixMult n m p _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- MatrixPlus tt (MatrixMult tt ll r) (MatrixMult tt l rr)
- | VectorPlus n _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- VectorPlus tt ll rr
- | VectorMinus n _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- VectorMinus tt ll rr
- | MatrixPlus n m _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- MatrixPlus tt ll rr
- | MatrixMinus n m _ l r =>
- let ll := df_deriv l v in
- let rr := df_deriv r v in
- MatrixMinus tt ll rr
- | VectorApply n _ x s r =>
- let rr := df_deriv r v in
- let ss := df_deriv s (x, DTfloat) in
- DVector tt (build_vector (fun i => Times tt (VectorElem tt rr i) (df_subst ss (x, DTfloat) (VectorElem tt r i))))
- | MatrixApply n m _ x s r =>
- let rr := df_deriv r v in
- let ss := df_deriv s (x, DTfloat) in
- DMatrix tt (build_matrix (fun i j => Times tt (MatrixElem tt rr i j) (df_subst ss (x, DTfloat) (MatrixElem tt r i j))))
- | VLossfun n _ v1 v2 s l r =>
- let ll := df_deriv l v in
- let ss := df_deriv s (v1, DTfloat) in
- VectorDot tt ll
- (DVector tt (build_vector (fun i =>
- df_subst (df_subst ss (v1, DTfloat) (VectorElem tt l i))
- (v2, DTfloat) (Number tt (vnth r i)))))
- | MLossfun n m _ v1 v2 s l r =>
- let ll := df_deriv l v in
- let ss := df_deriv s (v1, DTfloat) in
- MatrixSum tt
- (DMatrix tt
- (build_matrix (fun i j =>
- (Divide tt
- (Times tt (MatrixElem tt ll i j)
- (df_subst (df_subst ss (v1, DTfloat) (MatrixElem tt l i j))
- (v2, DTfloat) (Number tt (mnth r i j))))
- (Number tt (FfromZ (Z.of_nat m)))))))
- end).
-
- Definition df_gradient {T} (df:DefinedFunction UnitAnn T) (lv:list var_type) : list (DefinedFunction UnitAnn T)
- := map (df_deriv df) lv.
-
- End deriv.
-
- Section eval.
-
- Program
- Fixpoint vartlookup (l:df_env) (a:var_type) :
- option (definition_function_types_interp (snd a))
- := match l with
- | nil => None
- | fv::os => if a == (projT1 fv) then
- Some (eq_rect _ definition_function_types_interp (projT2 fv) _ _)
- else vartlookup os a
- end.
-
- Fixpoint vart_update (l:df_env) (a:var_type) (n:definition_function_types_interp (snd a)) : df_env
- := match l with
- | nil => (mk_env_entry a n)::nil
- | fv::os => if a == (projT1 fv) then
- (mk_env_entry a n)::os else fv::(vart_update os a n)
- end.
-
- Fixpoint df_eval {T Ann} (σ:df_env) (df:DefinedFunction Ann T) : option (definition_function_types_interp T)
- := match df with
- | Number _ r => Some r
- | Constant t _ x => Some x
- | DVector n _ dfs => vectoro_to_ovector (vmap (fun x => df_eval σ x) dfs)
- | DMatrix n m _ df => matrixo_to_omatrix (mmap (fun x => df_eval σ x) df)
- | Var x _ => vartlookup σ x
- | Plus _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (l' + r')
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (l' - r')
- | _, _ => None
- end
- | Times _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (l' * r')
- | _, _ => None
- end
- | Divide _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (l' / r')
- | _, _ => None
- end
- | Square _ e =>
- match df_eval σ e with
- | Some v => Some (v * v)
- | _ => None
- end
- | Exp _ e =>
- match df_eval σ e with
- | Some v => Some (Fexp v)
- | _ => None
- end
- | Log _ e =>
- match df_eval σ e with
- | Some v => Some (Fln v)
- | _ => None
- end
- | Abs _ e =>
- match df_eval σ e with
- | Some v => Some (Fabs v)
- | _ => None
- end
- | Sign _ e =>
- match df_eval σ e with
- | Some v => Some (sign v)
- | _ => None
- end
- | PSign _ e =>
- match df_eval σ e with
- | Some v => Some (pos_sign v)
- | _ => None
- end
- | Max _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (Fmax l' r')
- | _, _ => None
- end
- | VectorElem n _ l i =>
- match (df_eval σ l) with
- | Some l' => Some (vnth l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval σ l) with
- | Some l' => Some (mnth l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (vsum (vmap2 Fmult l' r'))
- | _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval σ l with
- | Some l' => Some (vsum l')
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval σ l with
- | Some l' => Some (msum l')
- | _ => None
- end
- | VectorScalMult n _ x r =>
- match df_eval σ x, df_eval σ r with
- | Some x', Some r' => Some (vmap (fun rr => x' * rr) r')
- | _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- match df_eval σ x, df_eval σ r with
- | Some x', Some r' => Some (mmap (fun rr => x' * rr) r')
- | _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (matrix_vector_mult l' r')
- | _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (matrix_vector_add l' r')
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (matrix_mult l' r')
- | _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (vmap2 Fplus l' r')
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (vmap2 Fminus l' r')
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (mmap2 Fplus l' r')
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some l', Some r' => Some (mmap2 Fminus l' r')
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- match df_eval σ r with
- | Some r' => vectoro_to_ovector
- (vmap (fun re =>
- let xv := (x, DTfloat):var_type in
- df_eval (cons (mk_env_entry xv re) nil) s) r')
- | _ => None
- end
- | MatrixApply n m _ x s r =>
- match df_eval σ r with
- | Some r' => matrixo_to_omatrix
- (mmap (fun re =>
- let xv := (x, DTfloat):var_type in
- df_eval (cons (mk_env_entry xv re) nil) s) r')
- | _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_eval σ l with
- | Some l' =>
- match (vectoro_to_ovector
- (build_vector (fun i =>
- let xv1 := (v1,DTfloat):var_type in
- let xv2 := (v2,DTfloat):var_type in
- df_eval (cons (mk_env_entry xv1 (vnth l' i))
- (cons (mk_env_entry xv2 (vnth r i)) nil)) s))) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- match df_eval σ l with
- | Some l' =>
- match (matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv1 := (v1,DTfloat):var_type in
- let xv2 := (v2,DTfloat):var_type in
- df_eval (cons (mk_env_entry xv1 (mnth l' i j))
- (cons (mk_env_entry xv2 (mnth r i j)) nil)) s))) with
- | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _ => None
- end
-
- end.
-
- Fixpoint df_eval_tree {T Ann} (σ:df_env) (df:DefinedFunction Ann T) : option (DefinedFunction EvalAnn T)
- := match df with
- | Number _ r => Some (Number r r)
- | Constant t _ x => Some (Constant x x)
- | DVector n _ dfs =>
- match vectoro_to_ovector (vmap (fun x => df_eval_tree σ x) dfs) with
- | Some val => Some (DVector (vmap get_annotation val) val)
- | _ => None
- end
- | DMatrix n m _ df =>
- match matrixo_to_omatrix (mmap (fun x => df_eval_tree σ x) df) with
- | Some val => Some (DMatrix
- (vmap (fun x => vmap get_annotation x) val) val)
- | _ => None
- end
- | Var x _ => match vartlookup σ x with
- | Some val => Some (Var x val)
- | _ => None
- end
- | Plus _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => Some (Plus ((get_annotation l') + (get_annotation r'))
- l' r')
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => Some (Minus ((get_annotation l') - (get_annotation r'))
- l' r')
- | _, _ => None
- end
- | Times _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => Some (Times ((get_annotation l') * (get_annotation r'))
- l' r')
- | _, _ => None
- end
- | Divide _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => Some (Divide ((get_annotation l') / (get_annotation r'))
- l' r')
- | _, _ => None
- end
- | Square _ e =>
- match df_eval_tree σ e with
- | Some vv => let v := get_annotation vv in Some (Square (v * v) vv)
- | _ => None
- end
- | Exp _ e =>
- match df_eval_tree σ e with
- | Some vv => Some (Exp (Fexp (get_annotation vv)) vv)
- | _ => None
- end
- | Log _ e =>
- match df_eval_tree σ e with
- | Some vv => Some (Log (Fln (get_annotation vv)) vv)
- | _ => None
- end
- | Abs _ e =>
- match df_eval_tree σ e with
- | Some vv => Some (Abs (Fabs (get_annotation vv)) vv)
- | _ => None
- end
- | Sign _ e =>
- match df_eval_tree σ e with
- | Some vv => Some (Sign (sign (get_annotation vv)) vv)
- | _ => None
- end
- | PSign _ e =>
- match df_eval_tree σ e with
- | Some vv => Some (PSign (pos_sign (get_annotation vv)) vv)
- | _ => None
- end
- | Max _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => Some (Max (Fmax (get_annotation l') (get_annotation r'))
- l' r')
- | _, _ => None
- end
- | VectorElem n _ l i =>
- match (df_eval_tree σ l) with
- | Some l' => let vl' := get_annotation l' in
- Some (VectorElem (vnth vl' i) l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_tree σ l) with
- | Some l' => let vl' := get_annotation l' in
- Some (MatrixElem (mnth vl' i j) l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- Some (VectorDot (vsum (build_vector (fun i => (vnth vl' i) * (vnth vr' i)))) l' r')
- | _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_tree σ l with
- | Some l' => let vl' := get_annotation l' in
- Some (VectorSum (vsum vl') l')
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_tree σ l with
- | Some l' => let vl' := get_annotation l' in
- Some (MatrixSum (msum vl') l')
- | _ => None
- end
- | VectorScalMult n _ x r =>
- match df_eval_tree σ x, df_eval_tree σ r with
- | Some x', Some r' => let vx' := get_annotation x' in
- let vr' := get_annotation r' in
- let vec : vector float n := vmap (fun re => vx' * re) vr' in
- Some (VectorScalMult vec x' r')
- | _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- match df_eval_tree σ x, df_eval_tree σ r with
- | Some x', Some r' => let vx' := get_annotation x' in
- let vr' := get_annotation r' in
- let mat : matrix float n m := mmap (fun re => vx' * re) vr' in
- Some (MatrixScalMult mat x' r')
- | _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- Some (MatrixVectorMult (matrix_vector_mult vl' vr') l' r')
- | _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- Some (MatrixVectorAdd (matrix_vector_add vl' vr') l' r')
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- Some (MatrixMult (matrix_mult vl' vr') l' r')
- | _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- let vec : vector float n :=
- vmap2 Fplus vl' vr' in
- Some (VectorPlus vec l' r')
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- let vec : vector float n :=
- vmap2 Fminus vl' vr' in
- Some (VectorMinus vec l' r')
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- let mat : matrix float n m :=
- mmap2 Fplus vl' vr' in
- Some (MatrixPlus mat l' r')
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_tree σ l, df_eval_tree σ r with
- | Some l', Some r' => let vl' := get_annotation l' in
- let vr' := get_annotation r' in
- let mat : matrix float n m :=
- mmap2 Fminus vl' vr' in
- Some (MatrixMinus mat l' r')
- | _, _ => None
- end
-
- | VectorApply n _ x s r =>
- match df_eval_tree σ r with
- | Some r' =>
- let vr' := get_annotation r' in
- match vectoro_to_ovector
- (build_vector (fun i =>
- let xv := (x, DTfloat):var_type in
- df_eval (cons (mk_env_entry xv (vnth vr' i)) nil) s)) with
- | Some val => Some (VectorApply val x s r')
- | _ => None
- end
- | _ => None
- end
- | MatrixApply n m _ x s r =>
- match df_eval_tree σ r with
- | Some r' =>
- let vr' := get_annotation r' in
- match matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv := (x, DTfloat):var_type in
- df_eval (cons (mk_env_entry xv (mnth vr' i j)) nil) s)) with
- | Some val => Some (MatrixApply val x s r')
- | _ => None
- end
- | _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_eval_tree σ l with
- | Some l' =>
- let vl' := get_annotation l' in
- match (vectoro_to_ovector
- (build_vector (fun i =>
- let xv1 := (v1,DTfloat):var_type in
- let xv2 := (v2,DTfloat):var_type in
- df_eval (cons (mk_env_entry xv1 (vnth vl' i))
- (cons (mk_env_entry xv2 (vnth r i)) nil)) s))) with
- | Some vv => Some (VLossfun (vsum vv) v1 v2 s l' r)
- | _ => None
- end
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- match df_eval_tree σ l with
- | Some l' =>
- let vl' := get_annotation l' in
- match (matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv1 := (v1,DTfloat):var_type in
- let xv2 := (v2,DTfloat):var_type in
- df_eval (cons (mk_env_entry xv1 (mnth vl' i j))
- (cons (mk_env_entry xv2 (mnth r i j)) nil)) s))) with
- | Some vv => Some (MLossfun ((msum vv)/(FfromZ (Z.of_nat m))) v1 v2 s l' r)
- | _ => None
- end
- | _ => None
- end
- end.
-
- Definition eval_env_entry_type := {T:definition_function_types & (DefinedFunction UnitAnn T) & definition_function_types_interp T}.
- Definition df_eval_env := list eval_env_entry_type.
-
- Definition mk_eval_env_entry {T} df val : eval_env_entry_type
- := let P := fun t => DefinedFunction UnitAnn t in
- let Q := fun t => definition_function_types_interp t in
- existT2 P Q T df val.
-
- Definition pair_update_evals {T} (df:DefinedFunction UnitAnn T) (val:definition_function_types_interp T) (dfevals : df_eval_env) : (definition_function_types_interp T * df_eval_env) :=
- (val, (mk_eval_env_entry df val)::dfevals).
-
- Fixpoint df_evals_list {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (dfevals : df_eval_env) : option (definition_function_types_interp T * df_eval_env)
- := match df with
- | Number _ r => Some (pair_update_evals (Number tt r) r dfevals)
- | Constant t _ x => Some (pair_update_evals (Constant tt x) x dfevals)
- | DVector n _ dfs => None (*vectoro_to_ovector (fun i => df_eval σ (dfs i))*)
- | DMatrix n m _ df => None (*matrixo_to_omatrix (fun i j => df_eval σ (df i j))*)
- | Var x _ =>
- match vartlookup σ x with
- | Some val => Some (pair_update_evals (Var x tt) val dfevals)
- | _ => None
- end
- | Plus _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') => Some (pair_update_evals (Plus tt l r) (l'+r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | Minus _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') => Some (pair_update_evals (Minus tt l r) (l'-r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | Times _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') => Some (pair_update_evals (Times tt l r) (l'*r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | Divide _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') => Some (pair_update_evals (Divide tt l r) (l'/r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | Square _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (Square tt e) (v * v) dfevals')
- | _ => None
- end
- | Exp _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (Exp tt e) (Fexp v) dfevals')
- | _ => None
- end
- | Log _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (Log tt e) (Fln v) dfevals')
- | _ => None
- end
- | Abs _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (Abs tt e) (Fabs v) dfevals')
- | _ => None
- end
- | Sign _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (Sign tt e) (sign v) dfevals')
- | _ => None
- end
- | PSign _ e =>
- match df_evals_list σ e dfevals with
- | Some (v, dfevals') => Some (pair_update_evals (PSign tt e) (pos_sign v) dfevals')
- | _ => None
- end
- | Max _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') => Some (pair_update_evals (Max tt l r) (Fmax l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | VectorElem n _ l i =>
- match (df_evals_list σ l dfevals) with
- | Some (l', dfevals') => Some (pair_update_evals (VectorElem tt l i) (vnth l' i) dfevals')
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_evals_list σ l dfevals) with
- | Some (l', dfevals') => Some (pair_update_evals (MatrixElem tt l i j) (mnth l' i j) dfevals')
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (VectorDot tt l r)
- (vsum (vmap2 Fmult l' r')) dfevals'')
- | _ => None
- end
- | _ => None
- end
- | VectorSum n _ l =>
- match df_evals_list σ l dfevals with
- | Some (l',dfevals') => Some (pair_update_evals (VectorSum tt l) (vsum l') dfevals')
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_evals_list σ l dfevals with
- | Some (l',dfevals') => Some (pair_update_evals (MatrixSum tt l) (msum l') dfevals')
- | _ => None
- end
- | VectorScalMult n _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (VectorScalMult tt l r)
- (vmap (fun re => l' * re) r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixScalMult n m _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixScalMult tt l r)
- (mmap (fun re => l' * re) r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixVectorMult tt l r)
- (matrix_vector_mult l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixVectorAdd tt l r)
- (matrix_vector_add l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixMult tt l r) (matrix_mult l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | VectorPlus n _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (VectorPlus tt l r)
- (vmap2 Fplus l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | VectorMinus n _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (VectorMinus tt l r)
- (vmap2 Fminus l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixPlus tt l r)
- (mmap2 Fplus l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_evals_list σ l dfevals with
- | Some (l', dfevals') =>
- match df_evals_list σ r dfevals' with
- Some (r', dfevals'') =>
- Some (pair_update_evals (MatrixMinus tt l r)
- (mmap2 Fminus l' r') dfevals'')
- | _ => None
- end
- | _ => None
- end
- | VectorApply n _ x s r =>
- match df_evals_list σ r dfevals with
-(* | Some r' => vectoro_to_ovector
- (fun i =>
- let xv := (x, DTfloat):var_type in
- df_eval (cons (mk_env_entry xv (r' i)) σ) s) *)
- | _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_evals_list σ l dfevals with
-(* | Some l' =>
- match (vectoro_to_ovector
- (fun i =>
- let xv1 := (v1,DTfloat):var_type in
- let xv2 := (v2,DTfloat):var_type in
- df_eval (cons (mk_env_entry xv1 (l' i))
- (cons (mk_env_entry xv2 (r i)) σ)) s)) with
- | Some vv => Some (vsum vv)
- | _ => None
- end *)
- | _ => None
- end
- | _ => None
- end.
-
-(*
- Program
- Fixpoint evalslookup {T} (l:df_eval_env) (df:DefinedFunction UnitAnn T) :
- option (definition_function_types_interp T)
- := match l with
- | nil => None
- | fv::os => if T == (projT1 (sigT_of_sigT2 fv)) then
- if df == (projT2 (sigT_of_sigT2 fv)) then
- Some (eq_rect _ definition_function_types_interp (projT3 fv) _ _)
- else evalslookup os df
- else evalslookup os df
- end.
-*)
- Definition df_eval_symbolic_gradient {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (lv:list var_type) : option (list (definition_function_types_interp T))
- := listo_to_olist (map (df_eval σ) (df_gradient df lv)).
-
- End eval.
-
- Section isderiv.
-
- Context (σ:df_env).
- Context (v:SubVar).
-(*
- Inductive is_deriv : DefinedFunction -> float -> Prop
- :=
- | is_deriv_Number (x : float) : is_deriv (Number x) 0
- | is_deriv_Var_eq : is_deriv (Var v) 1
- | is_deriv_Var_neq (sv : SubVar) : sv <> v -> is_deriv (Var sv) 0
- | is_deriv_Plus l l' r r' :
- is_deriv l l' ->
- is_deriv r r' ->
- is_deriv (Plus l r) (l' + r')
- | is_deriv_Minus l l' r r' :
- is_deriv l l' ->
- is_deriv r r' ->
- is_deriv (Minus l r) (l' - r')
- | is_deriv_Times l le l' r re r' :
- df_eval σ l = Some le ->
- is_deriv l l' ->
- df_eval σ r = Some re ->
- is_deriv r r' ->
- is_deriv (Times l r) ((le * r') + (l' * re))
- | is_deriv_Divide l le l' r re r' :
- df_eval σ l = Some le ->
- is_deriv l l' ->
- df_eval σ r = Some re ->
- is_deriv r r' ->
- is_deriv (Times l r)
- (((l' * re ) - (le * r'))
- / (re * re))
- | is_deriv_Exp e ee e' :
- df_eval σ e = Some ee ->
- is_deriv e e' ->
- is_deriv (Exp e) (e' * (Fexp ee))
- | is_deriv_Log e ee e' :
- df_eval σ e = Some ee ->
- is_deriv e e' ->
- is_deriv (Exp e) (e' / ee)
- | is_deriv_Abs e ee e' :
- df_eval σ e = Some ee ->
- is_deriv e e' -> is_deriv (Abs e) (e' * (sign ee))
- | is_deriv_Sign (e : DefinedFunction) :
- is_deriv (Sign e) 0
- | is_deriv_PSign (e : DefinedFunction) :
- is_deriv (PSign e) 0
- | is_deriv_Max_l l le l' re r :
- df_eval σ l = Some le ->
- df_eval σ r = Some re ->
- (le > re) = true ->
- is_deriv l l' ->
- is_deriv (Max l r) l'
- | is_deriv_Max_r l le r re r' :
- df_eval σ l = Some le ->
- df_eval σ r = Some re ->
- (re >= le) = true ->
- is_deriv r r' ->
- is_deriv (Max l r) r'.
- (*
- | is_deriv_Max_eq l l' ee r r' :
- df_eval σ l = Some ee ->
- df_eval σ r = Some ee ->
- is_deriv l l' ->
- is_deriv r r' ->
- is_deriv (Max l r) ((l' + r')/2) *)
-
-*)
- End isderiv.
-
- Section deriv2.
-
- Fixpoint df_eval_deriv {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v:var_type) : option (definition_function_types_interp T)
- := (match df with
- | Number _ _ => Some 0
- | Constant t _ x => Some
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ dfs => vectoro_to_ovector (vmap (fun xe => df_eval_deriv σ xe v) dfs)
- | DMatrix n m _ df => matrixo_to_omatrix (mmap (fun xe => df_eval_deriv σ xe v) df)
- | Var x _ => Some (let t:=snd x in
- match t return definition_function_types_interp t with
- | DTfloat => if x == v then 1 else 0
- | DTVector n => ConstVector n (if x == v then 1 else 0)
- | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0)
- end)
- | Plus _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some le, Some lr => Some (le + lr)
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some le, Some lr => Some (le - lr)
- | _, _ => None
- end
- | Times _ l r =>
- match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (le * rd +
- (ld * re))
- | _, _, _, _ => None
- end
- | Divide _ l r =>
- match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some ((ld / re) - ((le * rd) / (re * re)))
- | _, _, _, _ => None
- end
- | Square _ e =>
- match df_eval σ e, df_eval_deriv σ e v with
- | Some ee, Some ed => Some (2 * ee * ed)
- | _, _ => None
- end
- | Exp _ e =>
- match df_eval σ e, df_eval_deriv σ e v with
- | Some ee, Some ed => Some (ed * Fexp ee)
- | _, _ => None
- end
- | Log _ e =>
- match df_eval σ e, df_eval_deriv σ e v with
- | Some ee, Some ed => Some (ed / ee)
- | _, _ => None
- end
- | Abs _ e =>
- match df_eval σ e, df_eval_deriv σ e v with
- | Some ee, Some ed => Some (ed * (sign ee))
- | _, _ => None
- end
- | Sign _ e =>
- match df_eval_deriv σ e v with
- | Some _ => Some 0
- | None => None
- end
- | PSign _ e =>
- match df_eval_deriv σ e v with
- | Some _ => Some 0
- | None => None
- end
- | Max _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- if le <= re then df_eval_deriv σ r v else df_eval_deriv σ l v
- | _, _ => None
- end
- | VectorElem n _ l i =>
- match (df_eval_deriv σ l v) with
- | Some l' => Some (vnth l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_deriv σ l v) with
- | Some l' => Some (mnth l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (Fplus (vsum (vmap2 Fmult le rd)) (vsum (vmap2 Fmult ld re)))
- | _, _, _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_deriv σ l v with
- | Some ld =>
- Some (vsum ld)
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_deriv σ l v with
- | Some ld =>
- Some (msum ld)
- | _ => None
- end
- | VectorScalMult n _ x r =>
- match df_eval σ x, df_eval_deriv σ x v, df_eval σ r, df_eval_deriv σ r v with
- | Some xe, Some xd, Some re, Some rd =>
- Some (vmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re)
- | _, _, _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- match df_eval σ x, df_eval_deriv σ x v, df_eval σ r, df_eval_deriv σ r v with
- | Some xe, Some xd, Some re, Some rd =>
- Some (mmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re)
- | _, _, _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (build_vector (fun i => vsum (vmap4 (fun lei rde ldi ree => lei * rde + ldi * ree)
- (vnth le i) rd (vnth ld i) re)))
- | _, _, _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some ld, Some rd =>
- Some (build_matrix (fun i j => (mnth ld i j) + (vnth rd i)))
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_eval σ l, df_eval_deriv σ l v, df_eval σ r, df_eval_deriv σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (build_matrix (fun i k => vsum (build_vector (fun j => (mnth le i j)*(mnth rd j k) + (mnth ld i j)*(mnth re j k)))))
- | _, _, _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some l', Some r' => Some (vmap2 Fplus l' r')
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some l', Some r' => Some (vmap2 Fminus l' r')
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some l', Some r' => Some (mmap2 Fplus l' r')
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_deriv σ l v, df_eval_deriv σ r v with
- | Some l', Some r' => Some (mmap2 Fminus l' r')
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- match df_eval σ r, df_eval_deriv σ r v with
- | Some re, Some rd =>
- vectoro_to_ovector
- (build_vector (fun i =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv (vnth re i)) nil) s xv with
- | Some sd => Some ((vnth rd i) * sd)
- | _ => None
- end))
- | _, _ => None
- end
- | MatrixApply n m _ x s r =>
- match df_eval σ r, df_eval_deriv σ r v with
- | Some re, Some rd =>
- matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv (mnth re i j)) nil) s xv with
- | Some sd => Some ((mnth rd i j) * sd)
- | _ => None
- end))
- | _, _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_eval σ l, df_eval_deriv σ l v with
- | Some le, Some ld =>
- match (vectoro_to_ovector
- (build_vector (fun i =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv1 (vnth le i))
- (cons (mk_env_entry xv2 (vnth r i)) nil)) s xv1 with
- | Some sd => Some ((vnth ld i) * sd)
- | _ => None
- end))) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _, _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- match df_eval σ l, df_eval_deriv σ l v with
- | Some le, Some ld =>
- match (matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv1 (mnth le i j))
- (cons (mk_env_entry xv2 (mnth r i j)) nil)) s xv1 with
- | Some sd => Some ((mnth ld i j) * sd)
- | _ => None
- end))) with
- | Some vv => Some ((msum vv)/(FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _, _ => None
- end
- end).
-
- Definition mk_genvar_env (s:SubVar) := mk_env_entry (s, DTfloat) (FfromZ (Z.of_nat 1)) :: nil.
-
- (* the v environment below pairs variables with their derivatives *)
- (* in some sense this is giving a directional derivative defined by v *)
- Fixpoint df_eval_deriv_genvar {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v:df_env) : option (definition_function_types_interp T)
- := (match df with
- | Number _ _ => Some 0
- | Constant t _ x => Some
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ dfs => vectoro_to_ovector (vmap (fun df1 => df_eval_deriv_genvar σ df1 v) dfs)
- | DMatrix n m _ df => matrixo_to_omatrix (mmap (fun df1 => df_eval_deriv_genvar σ df1 v) df)
- | Var x _ => Some (
- match vartlookup v x with
- | Some val => val
- | _ =>
- match (snd x) with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- end)
- | Plus _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some le, Some lr => Some (le + lr)
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some le, Some lr => Some (le - lr)
- | _, _ => None
- end
- | Times _ l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (le * rd +
- (ld * re))
- | _, _, _, _ => None
- end
- | Divide _ l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some ((ld / re) - ((le * rd) / (re * re)))
- | _, _, _, _ => None
- end
- | Square _ e =>
- match df_eval σ e, df_eval_deriv_genvar σ e v with
- | Some ee, Some ed => Some (2 * ee * ed)
- | _, _ => None
- end
- | Exp _ e =>
- match df_eval σ e, df_eval_deriv_genvar σ e v with
- | Some ee, Some ed => Some (ed * Fexp ee)
- | _, _ => None
- end
- | Log _ e =>
- match df_eval σ e, df_eval_deriv_genvar σ e v with
- | Some ee, Some ed => Some (ed / ee)
- | _, _ => None
- end
- | Abs _ e =>
- match df_eval σ e, df_eval_deriv_genvar σ e v with
- | Some ee, Some ed => Some (ed * (sign ee))
- | _, _ => None
- end
- | Sign _ e =>
- match df_eval_deriv_genvar σ e v with
- | Some _ => Some 0
- | None => None
- end
- | PSign _ e =>
- match df_eval_deriv_genvar σ e v with
- | Some _ => Some 0
- | None => None
- end
- | Max _ l r =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- if le <= re then df_eval_deriv_genvar σ r v else df_eval_deriv_genvar σ l v
- | _, _ => None
- end
- | VectorElem n _ l i =>
- match (df_eval_deriv_genvar σ l v) with
- | Some l' => Some (vnth l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_deriv_genvar σ l v) with
- | Some l' => Some (mnth l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (vsum (vmap4 (fun le1 rd1 ld1 re1 => le1 * rd1 + ld1 * re1) le rd ld re))
- | _, _, _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_deriv_genvar σ l v with
- | Some ld =>
- Some (vsum ld)
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_deriv_genvar σ l v with
- | Some ld =>
- Some (msum ld)
- | _ => None
- end
- | VectorScalMult n _ x r =>
- match df_eval σ x, df_eval_deriv_genvar σ x v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some xe, Some xd, Some re, Some rd =>
- Some (vmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re)
- | _, _, _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- match df_eval σ x, df_eval_deriv_genvar σ x v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some xe, Some xd, Some re, Some rd =>
- Some (mmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re)
- | _, _, _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (build_vector (fun i => vsum (vmap4 (fun lei rde ldi ree => lei * rde + ldi * ree)
- (vnth le i) rd (vnth ld i) re)))
- | _, _, _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (build_matrix (fun i j => (mnth ld i j) + (vnth rd i)))
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v, df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (build_matrix (fun i k => vsum (build_vector (fun j => (mnth le i j)*(mnth rd j k) + (mnth ld i j)*(mnth re j k)))))
- | _, _, _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some l', Some r' => Some (vmap2 Fplus l' r')
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some l', Some r' => Some (vmap2 Fminus l' r')
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some l', Some r' => Some (mmap2 Fplus l' r')
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_deriv_genvar σ l v, df_eval_deriv_genvar σ r v with
- | Some l', Some r' => Some (mmap2 Fminus l' r')
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- match df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some re, Some rd =>
- vectoro_to_ovector
- (build_vector (fun i =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv_genvar (mk_env_entry xv (vnth re i) :: nil) s
- (mk_genvar_env x) with
- | Some sd => Some ((vnth rd i) * sd)
- | _ => None
- end))
- | _, _ => None
- end
- | MatrixApply n m _ x s r =>
- match df_eval σ r, df_eval_deriv_genvar σ r v with
- | Some re, Some rd =>
- matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv_genvar (mk_env_entry xv (mnth re i j) :: nil) s
- (mk_genvar_env x) with
- | Some sd => Some ((mnth rd i j) * sd)
- | _ => None
- end))
- | _, _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v with
- | Some le, Some ld =>
- match (vectoro_to_ovector
- (build_vector (fun i =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv_genvar ( mk_env_entry xv1 (vnth le i) ::
- mk_env_entry xv2 (vnth r i) :: nil) s
- (mk_genvar_env v1) with
- | Some sd => Some ((vnth ld i) * sd)
- | _ => None
- end))) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _, _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- match df_eval σ l, df_eval_deriv_genvar σ l v with
- | Some le, Some ld =>
- match (matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv_genvar ( mk_env_entry xv1 (mnth le i j) ::
- mk_env_entry xv2 (mnth r i j) :: nil) s
- ( mk_genvar_env v1) with
- | Some sd => Some ((mnth ld i j) * sd)
- | _ => None
- end))) with
- | Some vv => Some ((msum vv)/(FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _, _ => None
- end
- end).
-
-
- Definition definition_function_types_interp_prod (vart dft:definition_function_types) : Type
- := match vart with
- | DTfloat => definition_function_types_interp dft
- | DTVector n => Vector (definition_function_types_interp dft) n
- | DTMatrix m n => Matrix (definition_function_types_interp dft) m n
- end.
-
-
- Definition UnitVector (n:nat) (j : {n':nat | (n' < n)%nat}) : vector float n :=
- build_vector (fun i => if (proj1_sig i) == (proj1_sig j) then 1 else 0).
-
- Definition UnitMatrix (n m: nat)
- (i : {n':nat | (n' < n)%nat})
- (j : {m':nat | (m' < m)%nat}) : matrix float n m :=
- build_matrix (fun a b => if (proj1_sig a) == (proj1_sig i) then
- (if (proj1_sig b) == (proj1_sig j) then 1 else 0)
- else 0).
-
- Definition const_env (v : var_type) : df_env
- := match (snd v) with
- | DTfloat => ((mk_env_entry (fst v, DTfloat) 0)::nil)
- | DTVector n => ((mk_env_entry (fst v, DTVector n) (ConstVector n 0))::nil)
- | DTMatrix m n => ((mk_env_entry (fst v, DTMatrix m n) (ConstMatrix m n 0))::nil)
- end.
-
- Fixpoint df_eval_tree_deriv {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (v:var_type) : option (definition_function_types_interp T)
- := (match df with
- | Number _ _ => Some 0
- | Constant t _ x => Some
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ dfs => vectoro_to_ovector (vmap (fun e => df_eval_tree_deriv σ e v) dfs)
- | DMatrix n m _ df => matrixo_to_omatrix (mmap (fun e => df_eval_tree_deriv σ e v) df)
- | Var x _ => Some (let t:=snd x in
- match t return definition_function_types_interp t with
- | DTfloat => if x == v then 1 else 0
- | DTVector n => ConstVector n (if x == v then 1 else 0)
- | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0)
- end)
- | Plus _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some le, Some lr => Some (le + lr)
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some le, Some lr => Some (le - lr)
- | _, _ => None
- end
- | Times _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some (le * rd +
- (ld * re))
- | _, _ => None
- end
- | Divide _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some ((ld / re) - ((le * rd) / (re * re)))
- | _, _ => None
- end
- | Square _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv σ e v with
- | Some ed => Some (2 * ee * ed)
- | _ => None
- end
- | Exp _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv σ e v with
- | Some ed => Some (ed * Fexp ee)
- | _ => None
- end
- | Log _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv σ e v with
- | Some ed => Some (ed / ee)
- | _ => None
- end
- | Abs _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv σ e v with
- | Some ed => Some (ed * (sign ee))
- | _ => None
- end
- | Sign _ e =>
- match df_eval_tree_deriv σ e v with
- | Some _ => Some 0
- | None => None
- end
- | PSign _ e =>
- match df_eval_tree_deriv σ e v with
- | Some _ => Some 0
- | None => None
- end
- | Max _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- if le <= re then df_eval_tree_deriv σ r v else df_eval_tree_deriv σ l v
- | VectorElem n _ l i =>
- match (df_eval_tree_deriv σ l v) with
- | Some l' => Some (vnth l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_tree_deriv σ l v) with
- | Some l' => Some (mnth l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some (Fplus (vsum (vmap2 Fmult le rd)) (vsum (vmap2 Fmult ld re)))
- | _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_tree_deriv σ l v with
- | Some ld =>
- Some (vsum ld)
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_tree_deriv σ l v with
- | Some ld =>
- Some (msum ld)
- | _ => None
- end
- | VectorScalMult n _ x r =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_deriv σ x v, df_eval_tree_deriv σ r v with
- | Some xd, Some rd =>
- Some (vmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re)
- | _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_deriv σ x v, df_eval_tree_deriv σ r v with
- | Some xd, Some rd =>
- Some (mmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re)
- | _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some (build_vector (fun i => vsum (vmap4 (fun lei rde ldi ree => lei * rde + ldi * ree)
- (vnth le i) rd (vnth ld i) re)))
- | _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some (build_matrix (fun i j => (mnth ld i j) + (vnth rd i)))
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some ld, Some rd =>
- Some (build_matrix (fun i k => vsum (build_vector (fun j => (mnth le i j)*(mnth rd j k) + (mnth ld i j)*(mnth re j k)))))
- | _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some l', Some r' => Some (vmap2 Fplus l' r')
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some l', Some r' => Some (vmap2 Fminus l' r')
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some l', Some r' => Some (mmap2 Fplus l' r')
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_tree_deriv σ l v, df_eval_tree_deriv σ r v with
- | Some l', Some r' => Some (mmap2 Fminus l' r')
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- let re := get_annotation r in
- match df_eval_tree_deriv σ r v with
- | Some rd =>
- vectoro_to_ovector
- (build_vector (fun i =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv (vnth re i)) nil ) s xv with
- | Some sd => Some ((vnth rd i) * sd)
- | _ => None
- end))
- | _ => None
- end
- | MatrixApply n m _ x s r =>
- let re := get_annotation r in
- match df_eval_tree_deriv σ r v with
- | Some rd =>
- matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv (mnth re i j)) nil ) s xv with
- | Some sd => Some ((mnth rd i j) * sd)
- | _ => None
- end))
- | _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- let le := get_annotation l in
- match df_eval_tree_deriv σ l v with
- | Some ld =>
- match (vectoro_to_ovector
- (build_vector (fun i =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv1 (vnth le i))
- (cons (mk_env_entry xv2 (vnth r i)) nil)) s xv1 with
- | Some sd => Some ((vnth ld i) * sd)
- | _ => None
- end))) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- let le := get_annotation l in
- match df_eval_tree_deriv σ l v with
- | Some ld =>
- match (matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv (cons (mk_env_entry xv1 (mnth le i j))
- (cons (mk_env_entry xv2 (mnth r i j)) nil )) s xv1 with
- | Some sd => Some ((mnth ld i j) * sd)
- | _ => None
- end))) with
- | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _ => None
- end
- end).
-
- Fixpoint df_eval_tree_deriv_genvar {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (v:df_env) : option (definition_function_types_interp T)
- := (match df with
- | Number _ _ => Some 0
- | Constant t _ x => Some
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ dfs => vectoro_to_ovector (vmap (fun e => df_eval_tree_deriv_genvar σ e v) dfs)
- | DMatrix n m _ df => matrixo_to_omatrix (mmap (fun e => df_eval_tree_deriv_genvar σ e v) df)
- | Var x _ => Some (
- match vartlookup v x with
- | Some val => val
- | _ =>
- match (snd x) with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- end)
- | Plus _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some le, Some lr => Some (le + lr)
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some le, Some lr => Some (le - lr)
- | _, _ => None
- end
- | Times _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (le * rd +
- (ld * re))
- | _, _ => None
- end
- | Divide _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some ((ld / re) - ((le * rd) / (re * re)))
- | _, _ => None
- end
- | Square _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv_genvar σ e v with
- | Some ed => Some (2 * ee * ed)
- | _ => None
- end
- | Exp _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv_genvar σ e v with
- | Some ed => Some (ed * Fexp ee)
- | _ => None
- end
- | Log _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv_genvar σ e v with
- | Some ed => Some (ed / ee)
- | _ => None
- end
- | Abs _ e =>
- let ee := get_annotation e in
- match df_eval_tree_deriv_genvar σ e v with
- | Some ed => Some (ed * (sign ee))
- | _ => None
- end
- | Sign _ e =>
- match df_eval_tree_deriv_genvar σ e v with
- | Some _ => Some 0
- | None => None
- end
- | PSign _ e =>
- match df_eval_tree_deriv_genvar σ e v with
- | Some _ => Some 0
- | None => None
- end
- | Max _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- if le <= re then df_eval_tree_deriv_genvar σ r v else df_eval_tree_deriv_genvar σ l v
- | VectorElem n _ l i =>
- match (df_eval_tree_deriv_genvar σ l v) with
- | Some l' => Some (vnth l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_tree_deriv_genvar σ l v) with
- | Some l' => Some (mnth l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (vsum (vmap4 (fun le1 rd1 ld1 re1 => le1 * rd1 + ld1 * re1) le rd ld re))
- | _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_tree_deriv_genvar σ l v with
- | Some ld =>
- Some (vsum ld)
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_tree_deriv_genvar σ l v with
- | Some ld =>
- Some (msum ld)
- | _ => None
- end
- | VectorScalMult n _ x r =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_deriv_genvar σ x v, df_eval_tree_deriv_genvar σ r v with
- | Some xd, Some rd =>
- Some (vmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re)
- | _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_deriv_genvar σ x v, df_eval_tree_deriv_genvar σ r v with
- | Some xd, Some rd =>
- Some (mmap2 (fun rd1 re1 => xe * rd1 + xd * re1) rd re)
- | _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (build_vector (fun i => vsum (vmap4 (fun lei rde ldi ree => lei * rde + ldi * ree)
- (vnth le i) rd (vnth ld i) re)))
- | _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (build_matrix (fun i j => (mnth ld i j) + (vnth rd i)))
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some ld, Some rd =>
- Some (build_matrix (fun i k => vsum (build_vector (fun j => (mnth le i j)*(mnth rd j k) + (mnth ld i j)*(mnth re j k)))))
- | _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some l', Some r' => Some (vmap2 Fplus l' r')
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some l', Some r' => Some (vmap2 Fminus l' r')
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some l', Some r' => Some (mmap2 Fplus l' r')
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_tree_deriv_genvar σ l v, df_eval_tree_deriv_genvar σ r v with
- | Some l', Some r' => Some (mmap2 Fminus l' r')
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- let re := get_annotation r in
- match df_eval_tree_deriv_genvar σ r v with
- | Some rd =>
- vectoro_to_ovector
- (build_vector (fun i =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv_genvar (cons (mk_env_entry xv (vnth re i)) nil ) s
- (mk_genvar_env x) with
- | Some sd => Some ((vnth rd i) * sd)
- | _ => None
- end))
- | _ => None
- end
- | MatrixApply n m _ x s r =>
- let re := get_annotation r in
- match df_eval_tree_deriv_genvar σ r v with
- | Some rd =>
- matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv_genvar (cons (mk_env_entry xv (mnth re i j)) nil) s
- (mk_genvar_env x) with
- | Some sd => Some ((mnth rd i j) * sd)
- | _ => None
- end))
- | _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- let le := get_annotation l in
- match df_eval_tree_deriv_genvar σ l v with
- | Some ld =>
- match (vectoro_to_ovector
- (build_vector (fun i =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv_genvar (cons (mk_env_entry xv1 (vnth le i))
- (cons (mk_env_entry xv2 (vnth r i)) nil )) s
- (mk_genvar_env v1) with
- | Some sd => Some ((vnth ld i) * sd)
- | _ => None
- end))) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- let le := get_annotation l in
- match df_eval_tree_deriv_genvar σ l v with
- | Some ld =>
- match (matrixo_to_omatrix
- (build_matrix (fun i j =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv_genvar (cons (mk_env_entry xv1 (mnth le i j))
- (cons (mk_env_entry xv2 (mnth r i j)) nil)) s
- (mk_genvar_env v1) with
- | Some sd => Some ((mnth ld i j) * sd)
- | _ => None
- end))) with
- | Some vv => Some ((msum vv) / (FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _ => None
- end
- end).
-
- Definition vector_env_iter {n} {A} (f: A -> df_env -> option df_env)
- (env: df_env) (v : vector A n) : option df_env :=
- Vector.fold_right (fun a oenv => match oenv with
- | Some env => f a env
- | _ => None
- end)
- v (Some env).
-
- Fixpoint list_env_iter {A} (f: A -> df_env -> option df_env)
- (oenv:option df_env) (l: list A) : option df_env :=
- match oenv, l with
- | Some env, x :: l' => list_env_iter f (f x env) l'
- | _, _ => oenv
- end.
-
- Lemma list_env_iter_none {A} (f: A -> df_env -> option df_env) (l: list A) :
- list_env_iter f None l = None.
- Proof.
- induction l.
- now simpl.
- now simpl.
- Qed.
-
- Lemma list_env_iter_env_not_none {A} (f: A -> df_env -> option df_env)
- (oenv : option df_env) (l: list A):
- list_env_iter f oenv l <> None -> oenv <> None.
- Proof.
- intros.
- destruct oenv.
- + discriminate.
- + rewrite list_env_iter_none in H.
- tauto.
- Qed.
-
- Lemma list_env_iter_app {A} (f: A -> df_env -> option df_env)
- (oenv:option df_env) (l1 l2: list A) :
- list_env_iter f oenv (l1++l2) =
- list_env_iter f (list_env_iter f oenv l1) l2.
- Proof.
- revert l2 oenv.
- induction l1; intros l2 oenv; simpl.
- - now destruct oenv.
- - destruct oenv.
- + auto.
- + now rewrite list_env_iter_none.
- Qed.
-
-
- Lemma list_env_iter_ext {A} f1 f2 oenv (l:list A) :
- (forall x a, In x l -> f1 x a = f2 x a) ->
- list_env_iter f1 oenv l = list_env_iter f2 oenv l.
- Proof.
- intros fa.
- revert oenv.
- induction l; intros oenv; intros; simpl
- ; match_destr.
- rewrite fa; simpl; intuition.
- Qed.
-
-
- Definition two_vector_env_iter {n} {A B} (f: A -> B -> df_env -> option df_env)
- (env: df_env) (v: vector A n) (w: vector B n) : option df_env :=
- vector_env_iter (fun '(a,b) env => f a b env) env
- (vector_zip v w).
-
- Fixpoint two_vector_env_iter_alt2 {n1 n2} {A B} (f: A -> B -> df_env -> option df_env)
- (oenv: option df_env) (v: vector A n1) (w: vector B n2) : option df_env :=
- match oenv,v,w with
- | Some env, Vector.cons vx _ v', Vector.cons wx _ w' => two_vector_env_iter_alt2 f (f vx wx env) v' w'
- | oenv',_,_ => oenv'
- end.
-
- Definition two_vector_env_iter_alt {n} {A B} (f: A -> B -> df_env -> option df_env)
- (env: df_env) (v: vector A n) (w: vector B n) : option df_env :=
- list_env_iter (fun '(a,b) env => f a b env) (Some env) (combine (Vector.to_list v) (Vector.to_list w)).
-
- Definition matrix_env_iter {m n} {A} (f: A -> df_env -> option df_env)
- (env: option df_env) (mat : matrix A m n) : option df_env :=
- Vector.fold_right
- (fun vec oenv =>
- Vector.fold_right (fun a oenv => match oenv with
- | Some env => f a env
- | _ => None
- end) vec oenv
- ) mat env.
-
- Definition two_matrix_env_iter {n m} {A B} (f: A -> B -> df_env -> option df_env)
- (env: option df_env) (v: matrix A n m) (w: matrix B n m) : option df_env :=
- let vw := matrix_zip v w in
- matrix_env_iter (fun '(a,b) e => f a b e) env vw.
-
- Definition two_matrix_env_iter_alt {n m} {A B} (f: A -> B -> df_env -> option df_env)
- (env: df_env) (v: matrix A n m) (w: matrix B n m) : option df_env :=
- list_env_iter (fun i env => list_env_iter (fun j env => f (mnth v i j) (mnth w i j) env)
- (Some env) (bounded_seq0 m))
- (Some env) (bounded_seq0 n).
-
-
- Program Definition addvar (x : var_type) (grad_env:df_env) :=
- (match snd x as y return snd x = y ->
- definition_function_types_interp y ->
- definition_function_types_interp y with
- | DTfloat => fun pf grad => match vartlookup grad_env x with
- | Some val => grad + ((coerce _ val):float)
- | _ => grad
- end
- | DTVector n => fun pf grad => match vartlookup grad_env x with
- | Some val => build_vector (fun i => (vnth grad i) + (vnth ((coerce _ val):vector float n) i))
- | _ => grad
- end
- | DTMatrix m n => fun pf grad => match vartlookup grad_env x with
- | Some val => build_matrix (fun i j => (mnth ((coerce _ val):matrix float m n) i j) + (mnth grad i j))
- | _ => grad
- end
- end) (eq_refl _).
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
-
- Definition gradenv_init1 (v : var_type) : env_entry_type :=
- mk_env_entry v
- (match snd v as y return definition_function_types_interp y with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix n m => ConstMatrix n m 0
- end).
-
- Definition gradenv_init (dvars : list var_type) : df_env :=
- map gradenv_init1 dvars.
-
- Fixpoint df_eval_backprop_deriv {T Ann} (σ:df_env) (df:DefinedFunction Ann T) (grad_env:df_env) {struct df} : definition_function_types_interp T -> option df_env
- := match df with
- | Number _ _ => fun grad => Some grad_env
- | Constant _ _ _ => fun grad => Some grad_env
- | DVector n _ dfs =>
- fun grad =>
- ((fix tv_env_iter {n1 n2}
- (oenv: option df_env) (v: vector _ n1) (w: vector _ n2) : option df_env :=
- match oenv,v,w with
- | Some env, Vector.cons vx _ v', Vector.cons wx _ w' =>
- tv_env_iter (df_eval_backprop_deriv σ vx env wx) v' w'
- | oenv',_,_ => oenv'
- end)
- n n
- (Some grad_env) dfs grad )
- | DMatrix n m _ dfs => fun grad =>
- ((fix tm_env_iter {n1 m1 n2 m2}
- (oenv: option df_env) (v: matrix _ n1 m1) (w: matrix _ n2 m2) : option df_env :=
- match oenv,v,w with
- | Some env, Vector.cons vr _ v', Vector.cons wr _ w' =>
- let nenv :=
- ((fix tv_env_iter {n1 n2}
- (oenv: option df_env) (v: vector _ n1) (w: vector _ n2) : option df_env :=
- match oenv,v,w with
- | Some env, Vector.cons vx _ v', Vector.cons wx _ w' =>
- tv_env_iter (df_eval_backprop_deriv σ vx env wx) v' w'
- | oenv',_,_ => oenv'
- end)
- m1 m2
- (Some env) vr wr ) in
- tm_env_iter nenv v' w'
- | oenv',_,_ => oenv'
- end)
- n m n m
- (Some grad_env) dfs grad )
- | Var x _ => fun grad =>
- if vartlookup grad_env x then
- Some (vart_update grad_env x (addvar x grad_env grad))
- else Some grad_env
- | Plus _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | Minus _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (-grad)
- | _ => None
- end
- | Times _ l r => fun grad =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- match df_eval_backprop_deriv σ l grad_env (re * grad) with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (le * grad)
- | _ => None
- end
- | _, _ => None
- end
- | Divide _ l r => fun grad =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- match df_eval_backprop_deriv σ l grad_env (grad / re) with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (- le / (re * re) * grad)
- | _ => None
- end
- | _, _ => None
- end
- | Square _ e => fun grad =>
- match df_eval σ e with
- | Some ee => df_eval_backprop_deriv σ e grad_env (2 * ee * grad)
- | _ => None
- end
- | Exp _ e => fun grad =>
- match df_eval σ e with
- | Some ee => df_eval_backprop_deriv σ e grad_env (grad * Fexp ee)
- | _ => None
- end
- | Log _ e => fun grad =>
- match df_eval σ e with
- | Some ee => df_eval_backprop_deriv σ e grad_env (grad / ee)
- | _ => None
- end
- | Abs _ e => fun grad =>
- match df_eval σ e with
- | Some ee => df_eval_backprop_deriv σ e grad_env (grad * (sign ee))
- | _ => None
- end
- | Sign _ e => fun grad => df_eval_backprop_deriv σ e grad_env 0
- | PSign _ e => fun grad => df_eval_backprop_deriv σ e grad_env 0
- | Max _ l r => fun grad =>
- match df_eval σ l,
- df_eval σ r with
- | Some le, Some re =>
- if le <= re then
- (df_eval_backprop_deriv σ r grad_env grad) else
- (df_eval_backprop_deriv σ l grad_env grad)
- | _, _ => None
- end
- | VectorElem n _ l i => fun grad =>
- let grad' := fun k => if proj1_sig k == proj1_sig i then grad else 0 in
- df_eval_backprop_deriv σ l grad_env (build_vector grad')
- | MatrixElem m n _ l i j => fun grad =>
- let grad' := fun k1 k2 =>
- if (proj1_sig k1 == proj1_sig i) then
- if (proj1_sig k2 == proj1_sig j) then grad else 0
- else 0 in
- df_eval_backprop_deriv σ l grad_env (build_matrix grad')
- | VectorDot n _ l r => fun grad =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- match df_eval_backprop_deriv σ l grad_env (vmap (fun rv => rv*grad) re) with
- | Some grad_env' =>
- df_eval_backprop_deriv σ r grad_env' (vmap (fun lv => lv*grad) le)
- | _ => None
- end
- | _, _ => None
- end
- | VectorSum n _ l => fun grad =>
- df_eval_backprop_deriv σ l grad_env (ConstVector n grad)
- | MatrixSum n m _ l => fun grad =>
- df_eval_backprop_deriv σ l grad_env (ConstMatrix n m grad)
- | VectorScalMult n _ x r => fun grad =>
- match df_eval σ x, df_eval σ r with
- | Some xe, Some re =>
- match df_eval_backprop_deriv σ x grad_env (vsum (vmap2 Fmult re grad)) with
- | Some grad_env' =>
- df_eval_backprop_deriv σ r grad_env' (vmap (fun e => xe * e) grad)
- | _ => None
- end
- | _, _ => None
- end
- | MatrixScalMult n m _ x r => fun grad =>
- match df_eval σ x, df_eval σ r with
- | Some xe, Some re =>
- match df_eval_backprop_deriv σ x grad_env (msum (mmap2 Fmult re grad)) with
- | Some grad_env' =>
- df_eval_backprop_deriv σ r grad_env' (mmap (fun e => e * xe) grad)
- | _ => None
- end
- | _, _ => None
- end
- | MatrixVectorMult n m _ l r => fun grad =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- match df_eval_backprop_deriv σ l grad_env (build_matrix (fun i j => (vnth grad i) * (vnth re j))) with
-
-
- | Some grad_env' =>
- df_eval_backprop_deriv σ r grad_env' (matrix_vector_mult (transpose le) grad)
- | _ => None
- end
- | _, _ => None
- end
- | MatrixVectorAdd n m _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' =>
- match list_env_iter
- (fun i env => df_eval_backprop_deriv σ r env (vnth (transpose grad) i))
- (Some grad_env') (bounded_seq0 m) with
- | Some grad_env'' => Some grad_env''
- | _ => None
- end
- | _ => None
- end
- | MatrixMult n m p _ l r => fun grad =>
- match df_eval σ l, df_eval σ r with
- | Some le, Some re =>
- match df_eval_backprop_deriv σ l grad_env (matrix_mult grad (transpose re)) with
-
-
- | Some grad_env' =>
- df_eval_backprop_deriv σ r grad_env' (matrix_mult (transpose le) grad)
- | _ => None
- end
- | _, _ => None
- end
- | VectorPlus n _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | VectorMinus n _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (vmap Fopp grad)
- | _ => None
- end
- | MatrixPlus n m _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | MatrixMinus n m _ l r => fun grad =>
- match df_eval_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_backprop_deriv σ r grad_env' (mmap Fopp grad)
- | _ => None
- end
- | VectorApply n _ x s r => fun grad =>
- match df_eval σ r with
- | Some re =>
- let xv := (x, DTfloat):var_type in
- let s' := df_deriv s xv in
- let ograd :=
- vmap (fun '(rei, g) =>
-
- match df_eval (cons (mk_env_entry xv rei) nil) s' with
- | Some se => Some (g * se)
- | _ => None
- end)
- (vcombine re grad) in
- match vectoro_to_ovector ograd with
- | Some grad' => df_eval_backprop_deriv σ r grad_env grad'
- | _ => None
- end
- | _ => None
- end
- | MatrixApply n m _ x s r => fun grad =>
- match df_eval σ r with
- | Some re =>
- let xv := (x, DTfloat):var_type in
- let s' := df_deriv s xv in
- let ograd :=
- mmap (fun '(rei, g) =>
- match df_eval (cons (mk_env_entry xv rei) nil) s' with
- | Some se => Some (g * se)
- | _ => None
- end)
- (mcombine re grad) in
- match matrixo_to_omatrix ograd with
- | Some grad' => df_eval_backprop_deriv σ r grad_env grad'
- | _ => None
- end
- | _ => None
- end
- | VLossfun n _ v1 v2 s l re => fun grad =>
- match df_eval σ l with
- | Some le =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- let s' := df_deriv s xv1 in
- let ograd :=
- vmap (fun '(lei, rei) =>
- let senv := cons (mk_env_entry xv1 lei)
- (cons (mk_env_entry xv2 rei) nil) in
- match df_eval senv s' with
- | Some se => Some (grad * se)
- | _ => None
- end)
- (vcombine le re) in
- match vectoro_to_ovector ograd with
- | Some grad' => df_eval_backprop_deriv σ l grad_env grad'
- | _ => None
- end
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l re => fun grad =>
- match df_eval σ l with
- | Some le =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- let s' := df_deriv s xv1 in
- let ograd :=
- mmap (fun '(lei, rei) =>
- let senv := cons (mk_env_entry xv1 lei)
- (cons (mk_env_entry xv2 rei) nil) in
- match df_eval senv s' with
- | Some se => Some ((grad * se)/(FfromZ (Z.of_nat m)))
- | _ => None
- end)
- (mcombine le re) in
- match matrixo_to_omatrix ograd with
- | Some grad' => df_eval_backprop_deriv σ l grad_env grad'
- | _ => None
- end
- | _ => None
- end
- end.
-
- Definition lifted_type (B:Type) T
- := match T with
- | DTfloat => B
- | DTVector n => vector B n
- | DTMatrix m n => matrix B m n
- end.
-
- Fixpoint df_eval_tree_backprop_deriv {T} (σ:df_env) (df:DefinedFunction EvalAnn T) (grad_env:df_env) {struct df} : definition_function_types_interp T -> option df_env
- := match df with
- | Number _ _ => fun grad => Some grad_env
- | Constant _ _ _ => fun grad => Some grad_env
- | DVector n _ dfs => fun grad =>
- ((fix tv_env_iter {n1 n2}
- (oenv: option df_env) (v: vector _ n1) (w: vector _ n2) : option df_env :=
- match oenv,v,w with
- | Some env, Vector.cons vx _ v', Vector.cons wx _ w' =>
- tv_env_iter (df_eval_tree_backprop_deriv σ vx env wx) v' w'
- | oenv',_,_ => oenv'
- end)
- n n
- (Some grad_env) dfs grad )
- | DMatrix n m _ dfs => fun grad =>
- ((fix tm_env_iter {n1 m1 n2 m2}
- (oenv: option df_env) (v: matrix _ n1 m1) (w: matrix _ n2 m2) : option df_env :=
- match oenv,v,w with
- | Some env, Vector.cons vr _ v', Vector.cons wr _ w' =>
- let nenv :=
- ((fix tv_env_iter {n1 n2}
- (oenv: option df_env) (v: vector _ n1) (w: vector _ n2) : option df_env :=
- match oenv,v,w with
- | Some env, Vector.cons vx _ v', Vector.cons wx _ w' =>
- tv_env_iter (df_eval_backprop_deriv σ vx env wx) v' w'
- | oenv',_,_ => oenv'
- end)
- m1 m2
- (Some env) vr wr ) in
- tm_env_iter nenv v' w'
- | oenv',_,_ => oenv'
- end)
- n m n m
- (Some grad_env) dfs grad )
- | Var x _ => fun grad =>
- if vartlookup grad_env x then
- Some (vart_update grad_env x (addvar x grad_env grad))
- else Some grad_env
- | Plus _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | Minus _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (-grad)
- | _ => None
- end
- | Times _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_backprop_deriv σ l grad_env (re * grad) with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (le * grad)
- | _ => None
- end
- | Divide _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_backprop_deriv σ l grad_env (grad / re) with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (- le / (re * re) * grad)
- | _ => None
- end
- | Square _ e => fun grad =>
- let ee := get_annotation e in
- df_eval_tree_backprop_deriv σ e grad_env (2 * ee * grad)
- | Exp _ e => fun grad =>
- let ee := get_annotation e in
- df_eval_tree_backprop_deriv σ e grad_env (grad * Fexp ee)
- | Log _ e => fun grad =>
- let ee := get_annotation e in
- df_eval_tree_backprop_deriv σ e grad_env (grad / ee)
- | Abs _ e => fun grad =>
- let ee := get_annotation e in
- df_eval_tree_backprop_deriv σ e grad_env (grad * (sign ee))
- | Sign _ e => fun grad => df_eval_tree_backprop_deriv σ e grad_env 0
- | PSign _ e => fun grad => df_eval_tree_backprop_deriv σ e grad_env 0
- | Max _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- if le <= re then
- (df_eval_tree_backprop_deriv σ r grad_env grad) else
- (df_eval_tree_backprop_deriv σ l grad_env grad)
- | VectorElem n _ l i => fun grad =>
- let grad' := build_vector (fun k => if proj1_sig k == proj1_sig i then grad else 0) in
- df_eval_tree_backprop_deriv σ l grad_env grad'
- | MatrixElem m n _ l i j => fun grad =>
- let grad' := build_matrix (fun k1 k2 =>
- if (proj1_sig k1 == proj1_sig i) then
- if (proj1_sig k2 == proj1_sig j) then grad else 0
- else 0) in
- df_eval_tree_backprop_deriv σ l grad_env grad'
- | VectorDot n _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_backprop_deriv σ l grad_env (vmap (fun rv => rv*grad) re) with
- | Some grad_env' =>
- df_eval_tree_backprop_deriv σ r grad_env' (vmap (fun lv => lv*grad) le)
- | _ => None
- end
- | VectorSum n _ l => fun grad =>
- df_eval_tree_backprop_deriv σ l grad_env (ConstVector n grad)
- | MatrixSum n m _ l => fun grad =>
- df_eval_tree_backprop_deriv σ l grad_env (ConstMatrix n m grad)
- | VectorScalMult n _ x r => fun grad =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_backprop_deriv σ x grad_env (vsum (vmap2 Fmult re grad)) with
- | Some grad_env' =>
- df_eval_tree_backprop_deriv σ r grad_env' (vmap (fun ge => xe * ge) grad)
- | _ => None
- end
- | MatrixScalMult n m _ x r => fun grad =>
- let '(xe,re) := (get_annotation x, get_annotation r) in
- match df_eval_tree_backprop_deriv σ x grad_env (msum (mmap2 Fmult re grad)) with
- | Some grad_env' =>
- df_eval_tree_backprop_deriv σ r grad_env' (mmap (fun ge => ge*xe) grad)
- | _ => None
- end
- | MatrixVectorMult n m _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_backprop_deriv σ l grad_env (build_matrix (fun i j => (vnth grad i) * (vnth re j))) with
- | Some grad_env' =>
- df_eval_tree_backprop_deriv σ r grad_env' (matrix_vector_mult (transpose le) grad)
- | _ => None
- end
- | MatrixVectorAdd n m _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' =>
- match list_env_iter
- (fun i env => df_eval_tree_backprop_deriv σ r env (vnth (transpose grad) i))
- (Some grad_env') (bounded_seq0 m) with
- | Some grad_env'' => Some grad_env''
- | _ => None
- end
- | _ => None
- end
- | MatrixMult n m p _ l r => fun grad =>
- let '(le,re) := (get_annotation l, get_annotation r) in
- match df_eval_tree_backprop_deriv σ l grad_env (matrix_mult grad (transpose re)) with
- | Some grad_env' =>
- df_eval_tree_backprop_deriv σ r grad_env' (matrix_mult (transpose le) grad)
- | _ => None
- end
- | VectorPlus n _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | VectorMinus n _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (vmap Fopp grad)
- | _ => None
- end
- | MatrixPlus n m _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' grad
- | _ => None
- end
- | MatrixMinus n m _ l r => fun grad =>
- match df_eval_tree_backprop_deriv σ l grad_env grad with
- | Some grad_env' => df_eval_tree_backprop_deriv σ r grad_env' (mmap Fopp grad)
- | _ => None
- end
- | VectorApply n _ x s r => fun grad =>
- let re := get_annotation r in
- let xv := (x, DTfloat):var_type in
- let s' := df_deriv s xv in
- let ograd :=
- vmap (fun '(rei, g) =>
- match df_eval (cons (mk_env_entry xv rei) nil) s' with
- | Some se => Some (g * se)
- | _ => None
- end)
- (vector_zip re grad) in
- match vectoro_to_ovector ograd with
- | Some grad' => df_eval_tree_backprop_deriv σ r grad_env grad'
- | _ => None
- end
- | MatrixApply n m _ x s r => fun grad =>
- let re := get_annotation r in
- let xv := (x, DTfloat):var_type in
- let s' := df_deriv s xv in
- let ograd :=
- mmap (fun '(rei, g) =>
- match df_eval (cons (mk_env_entry xv rei) nil) s' with
- | Some se => Some (g * se)
- | _ => None
- end)
- (matrix_zip re grad) in
- match matrixo_to_omatrix ograd with
- | Some grad' => df_eval_tree_backprop_deriv σ r grad_env grad'
- | _ => None
- end
- | VLossfun n _ v1 v2 s l re => fun grad =>
- let le := get_annotation l in
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- let s' := df_deriv s xv1 in
- let ograd :=
- vmap (fun '(lei, rei) =>
- let senv := cons (mk_env_entry xv1 lei)
- (cons (mk_env_entry xv2 rei) nil) in
- match df_eval senv s' with
- | Some se => Some (grad * se)
- | _ => None
- end)
- (vector_zip le re) in
- match vectoro_to_ovector ograd with
- | Some grad' => df_eval_tree_backprop_deriv σ l grad_env grad'
- | _ => None
- end
- | MLossfun n m _ v1 v2 s l re => fun grad =>
- let le := get_annotation l in
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- let s' := df_deriv s xv1 in
- let ograd :=
- mmap (fun '(lei, rei) =>
- let senv := cons (mk_env_entry xv1 lei)
- (cons (mk_env_entry xv2 rei) nil) in
- match df_eval senv s' with
- | Some se => Some ((grad * se) / (FfromZ (Z.of_nat m)))
- | _ => None
- end)
- (matrix_zip le re) in
- match matrixo_to_omatrix ograd with
- | Some grad' => df_eval_tree_backprop_deriv σ l grad_env grad'
- | _ => None
- end
- end.
-
- Definition o_df_env_to_df_env (oenv : option df_env) : df_env :=
- match oenv with
- | Some env => env
- | _ => nil
- end.
-
-
- Definition backprop_lookup (oenv:option df_env) (a:var_type) :
- option (definition_function_types_interp (snd a)) :=
- match oenv with
- | Some env =>
- match vartlookup env a with
- | Some val => Some val
- | _ => None
- end
- | _ => None
- end.
-
- Definition is_scalar_df_type (dft:definition_function_types) : Prop
- := match dft with
- | DTfloat => True
- | _ => False
- end.
-
- Fixpoint is_scalar_function {Ann} {T} (df:DefinedFunction Ann T) : Prop
- := match df with
- | Number _ _ => True
- | Constant t _ _ => is_scalar_df_type t
- | Var v _ => is_scalar_df_type (snd v)
- | Plus _ l r => is_scalar_function l /\ is_scalar_function r
- | Minus _ l r => is_scalar_function l /\ is_scalar_function r
- | Times _ l r => is_scalar_function l /\ is_scalar_function r
- | Divide _ l r => is_scalar_function l /\ is_scalar_function r
- | Square _ e => is_scalar_function e
- | Exp _ e => is_scalar_function e
- | Log _ e => is_scalar_function e
- | Abs _ e => is_scalar_function e
- | Sign _ e => is_scalar_function e
- | PSign _ e => is_scalar_function e
- | Max _ l r => is_scalar_function l /\ is_scalar_function r
- | _ => False
- end.
-
- Fixpoint has_scalar_functions {Ann} {T}
- (df:DefinedFunction Ann T) {struct df}: Prop
- := match df with
- | Number _ _ => True
- | Constant _ _ _ => True
- | DVector n _ vec =>
- ((fix vforal {n'} (v:vector (DefinedFunction Ann DTfloat) n') : Prop :=
- match v with
- | Vector.nil => True
- | Vector.cons vx _ v' => (has_scalar_functions vx) /\ (vforal v')
- end)
- n vec)
- | DMatrix n m _ mat =>
- ((fix mforall {n' m'} (mat:matrix (DefinedFunction Ann DTfloat) n' m') : Prop :=
- match mat with
- | Vector.nil => True
- | Vector.cons matr _ mat' =>
- (mforall mat') /\
- ((fix vforal {n'} (v:vector (DefinedFunction Ann DTfloat) n') : Prop :=
- match v with
- | Vector.nil => True
- | Vector.cons vx _ v' => (has_scalar_functions vx) /\ (vforal v')
- end)
- m' matr)
- end)
- n m mat)
- | Var _ _ => True
- | Plus _ l r => (has_scalar_functions l) /\ (has_scalar_functions r)
- | Minus _ l r => (has_scalar_functions l) /\ (has_scalar_functions r)
- | Times _ l r => (has_scalar_functions l) /\ (has_scalar_functions r)
- | Divide _ l r => (has_scalar_functions l) /\ (has_scalar_functions r)
- | Square _ l => has_scalar_functions l
- | Exp _ l => has_scalar_functions l
- | Log _ l => has_scalar_functions l
- | Abs _ l => has_scalar_functions l
- | Sign _ l => has_scalar_functions l
- | PSign _ l => has_scalar_functions l
- | Max _ l r => (has_scalar_functions l) /\ (has_scalar_functions r)
- | VectorDot n _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | VectorSum _ _ l => has_scalar_functions l
- | MatrixSum _ _ _ l => has_scalar_functions l
- | VectorElem _ _ vec i => has_scalar_functions vec
- | MatrixElem _ _ _ mat i j => has_scalar_functions mat
- | MatrixVectorMult _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | MatrixVectorAdd _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | MatrixMult _ _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | VectorPlus _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | VectorMinus _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | MatrixPlus _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | MatrixMinus _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | VectorScalMult _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | MatrixScalMult _ _ _ l r => (has_scalar_functions l) /\
- (has_scalar_functions r)
- | VectorApply _ _ _ s l => is_scalar_function s /\ has_scalar_functions l
- | MatrixApply _ _ _ _ s l => is_scalar_function s /\ has_scalar_functions l
- | VLossfun _ _ _ _ s l _ => is_scalar_function s /\ has_scalar_functions l
- | MLossfun _ _ _ _ _ s l _ => is_scalar_function s /\ has_scalar_functions l
- end.
-
- Lemma is_scalar_function_has_scalar_functions {Ann} {T} (df:DefinedFunction Ann T) :
- is_scalar_function df -> has_scalar_functions df.
- Proof.
- induction df; firstorder.
- Qed.
-
- Hint Resolve is_scalar_function_has_scalar_functions.
-
- Definition DefinedFunction_ind_unit_has_scalar_functions
- (P : forall (d : definition_function_types), DefinedFunction UnitAnn d -> Prop)
- (f : forall (ann : UnitAnn DTfloat) (x : float),
- P DTfloat (Number ann x))
- (f0 : forall (t : definition_function_types)
- (ann : UnitAnn t) (x : definition_function_types_interp t), P t (Constant ann x))
- (f1 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (x : vector (DefinedFunction UnitAnn DTfloat) n)
- (f: vforall (P DTfloat) x),
- P (DTVector n) (DVector ann x))
- (f2 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (x : matrix (DefinedFunction UnitAnn DTfloat) n m)
- (f: mforall (P DTfloat) x),
- P (DTMatrix n m) (DMatrix ann x))
- (f3 : forall (v : var_type) (ann : UnitAnn (snd v)),
- P (snd v) (Var v ann))
- (f4 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Plus ann l r))
- (f5 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Minus ann l r))
- (f6 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Times ann l r))
- (f7 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Divide ann l r))
- (f8 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Square ann e))
- (f9 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Exp ann e))
- (f10 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Log ann e))
- (f11 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Abs ann e))
- (f12 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (Sign ann e))
- (f13 : forall (ann : UnitAnn DTfloat)
- (e : DefinedFunction UnitAnn DTfloat), P DTfloat e -> P DTfloat (PSign ann e))
- (f14 : forall (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn DTfloat),
- P DTfloat l ->
- forall r : DefinedFunction UnitAnn DTfloat, P DTfloat r -> P DTfloat (Max ann l r))
- (f15 : forall (n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P DTfloat (VectorDot ann l r))
- (f16 : forall (n : nat) (ann : UnitAnn DTfloat)
- (v : DefinedFunction UnitAnn (DTVector n)), P (DTVector n) v -> P DTfloat (VectorSum ann v))
- (f17 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (v : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) v -> P DTfloat (MatrixSum ann v))
- (f18 : forall (n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall i : {x : nat | (x < n)%nat}, P DTfloat (VectorElem ann l i))
- (f19 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall (i : {x : nat | (x < m)%nat}) (j : {x : nat | (x < n)%nat}),
- P DTfloat (MatrixElem ann l i j))
- (f20 : forall (m n : nat) (ann : UnitAnn (DTVector m))
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector m) (MatrixVectorMult ann l r))
- (f21 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n))
- (l : DefinedFunction UnitAnn (DTMatrix m n)),
- P (DTMatrix m n) l ->
- forall r : DefinedFunction UnitAnn (DTVector m),
- P (DTVector m) r -> P (DTMatrix m n) (MatrixVectorAdd ann l r))
- (f22 : forall (m p n : nat) (ann : UnitAnn (DTMatrix m n))
- (l : DefinedFunction UnitAnn (DTMatrix m p)),
- P (DTMatrix m p) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix p n),
- P (DTMatrix p n) r -> P (DTMatrix m n) (MatrixMult ann l r))
- (f23 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorPlus ann l r))
- (f24 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (l : DefinedFunction UnitAnn (DTVector n)),
- P (DTVector n) l ->
- forall r : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) r -> P (DTVector n) (VectorMinus ann l r))
- (f25 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (l : DefinedFunction UnitAnn (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixPlus ann l r))
- (f26 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (l : DefinedFunction UnitAnn (DTMatrix n m)),
- P (DTMatrix n m) l ->
- forall r : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) r -> P (DTMatrix n m) (MatrixMinus ann l r))
- (f27 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (x : DefinedFunction UnitAnn DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorScalMult ann x l))
- (f28 : forall (n m : nat) (ann : UnitAnn (DTMatrix n m))
- (x : DefinedFunction UnitAnn DTfloat),
- P DTfloat x ->
- forall l : DefinedFunction UnitAnn (DTMatrix n m),
- P (DTMatrix n m) l -> P (DTMatrix n m) (MatrixScalMult ann x l))
- (f29 : forall (n : nat) (ann : UnitAnn (DTVector n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- is_scalar_function s ->
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> P (DTVector n) (VectorApply ann v s l))
- (f30 : forall (m n : nat) (ann : UnitAnn (DTMatrix m n))
- (v : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- is_scalar_function s ->
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTMatrix m n),
- P (DTMatrix m n) l -> P (DTMatrix m n) (MatrixApply ann v s l))
- (f31 : forall (n : nat) (ann : UnitAnn DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- is_scalar_function s ->
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTVector n),
- P (DTVector n) l -> forall r : vector float n, P DTfloat (VLossfun ann v1 v2 s l r))
- (f32 : forall (m n : nat) (ann : UnitAnn DTfloat)
- (v1 v2 : SubVar) (s : DefinedFunction UnitAnn DTfloat),
- is_scalar_function s ->
- P DTfloat s ->
- forall l : DefinedFunction UnitAnn (DTMatrix m n),
- P (DTMatrix m n) l ->
- forall r : matrix float m n, P DTfloat (MLossfun ann v1 v2 s l r))
- : forall (d : definition_function_types)
- (d0 : DefinedFunction UnitAnn d)
- (hs:has_scalar_functions d0), P d d0.
- Proof.
- refine (fix
- F (d : definition_function_types)
- (d0 : DefinedFunction UnitAnn d)
- {struct d0} : has_scalar_functions d0 -> P d d0 :=
- match d0 as d2 in (DefinedFunction _ d1) return has_scalar_functions d2 -> (P d1 d2) with
- | Number ann x => fun hs => f ann x
- | @Constant _ t ann x => fun hs => f0 t ann x
- | @DVector _ n ann x =>
- fun hs =>
- f1 n ann x
- ((fix F1 n (x:vector (DefinedFunction UnitAnn DTfloat) n) : vforall (P DTfloat) x :=
- match x with
- | Vector.nil => Vector.Forall_nil (P DTfloat)
- | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h _) (F1 _ tl)
- end) n x)
- | @DMatrix _ n m ann x =>
- fun hs =>
- f2 n m ann x
- ((fix F2 n m (x:matrix (DefinedFunction UnitAnn DTfloat) n m) : mforall (P DTfloat) x :=
- match x with
- | Vector.nil => Vector.Forall_nil (vforall (P DTfloat))
-
- | Vector.cons h _ tl =>
- Vector.Forall_cons _ _ _
- ((fix F1 m (x:vector (DefinedFunction UnitAnn DTfloat) m) : vforall (P DTfloat) x :=
- match x with
- | Vector.nil => Vector.Forall_nil (P DTfloat)
- | Vector.cons h _ tl => Vector.Forall_cons _ _ _ (F DTfloat h _) (F1 _ tl)
- end) m h)
- (F2 _ _ tl)
- end) n m x)
- | Var v ann => fun hs => f3 v ann
- | Plus ann l r => fun hs => f4 ann l (F DTfloat l (proj1 hs)) r (F DTfloat r (proj2 hs))
- | Minus ann l r => fun hs => f5 ann l (F DTfloat l _) r (F DTfloat r _)
- | Times ann l r => fun hs => f6 ann l (F DTfloat l _) r (F DTfloat r _)
- | Divide ann l r => fun hs => f7 ann l (F DTfloat l _) r (F DTfloat r _)
- | Square ann e => fun hs => f8 ann e (F DTfloat e _)
- | Exp ann e => fun hs => f9 ann e (F DTfloat e _)
- | Log ann e => fun hs => f10 ann e (F DTfloat e _)
- | Abs ann e => fun hs => f11 ann e (F DTfloat e _)
- | Sign ann e => fun hs => f12 ann e (F DTfloat e _)
- | PSign ann e => fun hs => f13 ann e (F DTfloat e _)
- | Max ann l r => fun hs => f14 ann l (F DTfloat l _) r (F DTfloat r _)
- | @VectorDot _ n ann l r => fun hs => f15 n ann l (F (DTVector n) l _) r (F (DTVector n) r _)
- | @VectorSum _ n ann v => fun hs => f16 n ann v (F (DTVector n) v _)
- | @MatrixSum _ m n ann v => fun hs => f17 m n ann v (F (DTMatrix m n) v _)
- | @VectorElem _ n ann l i => fun hs => f18 n ann l (F (DTVector n) l _) i
- | @MatrixElem _ m n ann l i j => fun hs => f19 m n ann l (F (DTMatrix m n) l _) i j
- | @MatrixVectorMult _ m n ann l r => fun hs =>
- f20 m n ann l (F (DTMatrix m n) l _) r (F (DTVector n) r _)
- | @MatrixVectorAdd _ m n ann l r => fun hs =>
- f21 m n ann l (F (DTMatrix m n) l _) r (F (DTVector m) r _)
- | @MatrixMult _ m p n ann l r => fun hs =>
- f22 m p n ann l (F (DTMatrix m p) l _) r (F (DTMatrix p n) r _)
- | @VectorPlus _ n ann l r => fun hs => f23 n ann l (F (DTVector n) l _) r (F (DTVector n) r _)
- | @VectorMinus _ n ann l r => fun hs => f24 n ann l (F (DTVector n) l _) r (F (DTVector n) r _)
- | @MatrixPlus _ n m ann l r => fun hs => f25 n m ann l (F (DTMatrix n m) l _) r (F (DTMatrix n m) r _)
- | @MatrixMinus _ n m ann l r => fun hs =>
- f26 n m ann l (F (DTMatrix n m) l _) r (F (DTMatrix n m) r _)
- | @VectorScalMult _ n ann x l => fun hs => f27 n ann x (F DTfloat x _) l (F (DTVector n) l _)
- | @MatrixScalMult _ n m ann x l => fun hs => f28 n m ann x (F DTfloat x _) l (F (DTMatrix n m) l _)
- | @VectorApply _ n ann v s l => fun hs => f29 n ann v s _ (F DTfloat s _) l (F (DTVector n) l _)
- | @MatrixApply _ m n ann v s l => fun hs =>
- f30 m n ann v s _ (F DTfloat s _) l (F (DTMatrix m n) l _)
- | @VLossfun _ n ann v1 v2 s l r => fun hs =>
- f31 n ann v1 v2 s _ (F DTfloat s _) l (F (DTVector n) l _) r
- | @MLossfun _ m n ann v1 v2 s l r => fun hs =>
- f32 m n ann v1 v2 s _ (F DTfloat s _) l (F (DTMatrix m n) l _) r
- end); simpl in hs; intuition.
- - exact (proj1 (vforall_forall has_scalar_functions x) hs s).
- - rewrite vforall_forall in hs.
- specialize (hs s).
- rewrite vforall_forall in hs.
- specialize (hs s0).
- exact hs.
- Defined.
-
- Fixpoint is_df_rec_prop {Ann} {T}
- (prop : forall TT:definition_function_types,
- (DefinedFunction Ann TT) -> Prop)
- (df:DefinedFunction Ann T) {struct df}: Prop
- := prop T df /\
- match df with
- | Number _ _ => True
- | Constant _ _ _ => True
- | DVector n _ vec =>
- vforall (is_df_rec_prop prop) vec
- | DMatrix n m _ mat =>
- vforall (vforall (is_df_rec_prop prop)) mat
- | Var _ _ => True
- | Plus _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r)
- | Minus _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r)
- | Times _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r)
- | Divide _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r)
- | Square _ l => is_df_rec_prop prop l
- | Exp _ l => is_df_rec_prop prop l
- | Log _ l => is_df_rec_prop prop l
- | Abs _ l => is_df_rec_prop prop l
- | Sign _ l => is_df_rec_prop prop l
- | PSign _ l => is_df_rec_prop prop l
- | Max _ l r => (is_df_rec_prop prop l) /\ (is_df_rec_prop prop r)
- | VectorDot n _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | VectorSum _ _ l => is_df_rec_prop prop l
- | MatrixSum _ _ _ l => is_df_rec_prop prop l
- | VectorElem _ _ vec i => is_df_rec_prop prop vec
- | MatrixElem _ _ _ mat i j => is_df_rec_prop prop mat
- | MatrixVectorMult _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | MatrixVectorAdd _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | MatrixMult _ _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | VectorPlus _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | VectorMinus _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | MatrixPlus _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | MatrixMinus _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | VectorScalMult _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | MatrixScalMult _ _ _ l r => (is_df_rec_prop prop l) /\
- (is_df_rec_prop prop r)
- | VectorApply _ _ _ _ l => is_df_rec_prop prop l
- | MatrixApply _ _ _ _ _ l => is_df_rec_prop prop l
- | VLossfun _ _ _ _ _ l _ => is_df_rec_prop prop l
- | MLossfun _ _ _ _ _ _ l _ => is_df_rec_prop prop l
- end.
-
- Fixpoint df_strip_annotations {Ann} {T}
- (df:DefinedFunction Ann T) {struct df}: DefinedFunction UnitAnn T
- :=
- match df with
- | Number _ x1 => Number tt x1
- | Constant t _ x => Constant tt x
- | DVector n _ vec => DVector tt (vmap df_strip_annotations vec)
- | DMatrix n m _ mat => DMatrix tt (vmap (vmap df_strip_annotations) mat)
- | Var v _ => Var v tt
- | Plus _ l r => Plus tt (df_strip_annotations l) (df_strip_annotations r)
- | Minus _ l r => Minus tt (df_strip_annotations l) (df_strip_annotations r)
- | Times _ l r => Times tt (df_strip_annotations l) (df_strip_annotations r)
- | Divide _ l r => Divide tt (df_strip_annotations l) (df_strip_annotations r)
- | Square _ l => Square tt (df_strip_annotations l)
- | Exp _ l => Exp tt (df_strip_annotations l)
- | Log _ l => Log tt (df_strip_annotations l)
- | Abs _ l => Abs tt (df_strip_annotations l)
- | Sign _ l => Sign tt (df_strip_annotations l)
- | PSign _ l => PSign tt (df_strip_annotations l)
- | Max _ l r => Max tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorDot n _ l r => VectorDot tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorSum n _ l => VectorSum tt (df_strip_annotations l)
- | MatrixSum m n _ l => MatrixSum tt (df_strip_annotations l)
- | VectorElem n _ vec i => VectorElem tt (df_strip_annotations vec) i
- | MatrixElem m n _ mat i j => MatrixElem tt (df_strip_annotations mat) i j
- | MatrixVectorMult m n _ l r => MatrixVectorMult tt (df_strip_annotations l) (df_strip_annotations r)
- | MatrixVectorAdd m n _ l r => MatrixVectorAdd tt (df_strip_annotations l) (df_strip_annotations r)
- | MatrixMult m p n _ l r => MatrixMult tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorPlus n _ l r => VectorPlus tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorMinus n _ l r => VectorMinus tt (df_strip_annotations l) (df_strip_annotations r)
- | MatrixPlus m n _ l r => MatrixPlus tt (df_strip_annotations l) (df_strip_annotations r)
- | MatrixMinus m n _ l r => MatrixMinus tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorScalMult n _ l r => VectorScalMult tt (df_strip_annotations l) (df_strip_annotations r)
- | MatrixScalMult m n _ l r => MatrixScalMult tt (df_strip_annotations l) (df_strip_annotations r)
- | VectorApply n _ v s l => VectorApply tt v (df_strip_annotations s) (df_strip_annotations l)
- | MatrixApply m n _ v s l => MatrixApply tt v (df_strip_annotations s) (df_strip_annotations l)
- | VLossfun n _ v1 v2 s l r => VLossfun tt v1 v2 (df_strip_annotations s) (df_strip_annotations l) r
- | MLossfun m n _ v1 v2 s l r => MLossfun tt v1 v2 (df_strip_annotations s) (df_strip_annotations l) r
- end.
-
- Require Import Program.
-
- Lemma df_strip_annotations_id {T} (df:DefinedFunction UnitAnn T) : df_strip_annotations df = df.
- Proof.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case; simpl; trivial
- ; destruct ann; trivial; try congruence.
- - Case "DVector"%string.
- f_equal.
- erewrite vmap_ext; [apply vmap_id | ]; intros.
- simpl.
- destruct H0 as [??]; subst.
- eapply H; eauto.
- - Case "DMatrix"%string.
- f_equal.
- erewrite vmap_ext; [apply vmap_id | ]; intros.
- simpl.
- erewrite vmap_ext; [apply vmap_id | ]; intros.
- simpl.
- destruct H0 as [??]; subst.
- destruct H1 as [??]; subst.
- eapply H; eauto.
- Qed.
-
- Definition df_eq_upto_annotations {Ann1 Ann2 T}
- (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) : Prop
- := df_strip_annotations df1 = df_strip_annotations df2.
-
- Definition is_df_evalann_correct_top (σ:df_env) {T} (df:DefinedFunction EvalAnn T)
- := df_eval σ df = Some (get_annotation df).
-
- Definition is_df_evalann_correct (σ:df_env) {T} (df:DefinedFunction EvalAnn T)
- := is_df_rec_prop (@is_df_evalann_correct_top σ) df.
-
- Lemma is_df_rec_prop_top {Ann} {T}
- {prop : forall TT:definition_function_types,
- (DefinedFunction Ann TT) -> Prop}
- {df:DefinedFunction Ann T} :
- is_df_rec_prop prop df ->
- prop _ df.
- Proof.
- destruct df; simpl; tauto.
- Qed.
-
- Lemma df_eval_tree_correct {T Ann} (σ:df_env) (df:DefinedFunction Ann T) (dfann:DefinedFunction EvalAnn T):
- df_eval_tree σ df = Some dfann ->
- is_df_evalann_correct σ dfann.
- Proof.
- unfold is_df_evalann_correct, is_df_evalann_correct_top.
- revert dfann.
- DefinedFunction_cases (induction df) Case; simpl; intros dfann eqq
- ; try solve[case_eq (df_eval_tree σ df1)
- ; [intros adf1 a1eqq | intros a1eqq]
- ; rewrite a1eqq in eqq
- ; [| congruence]
- ; (case_eq (df_eval_tree σ df2)
- ; [intros adf2 a2eqq | intros a2eqq]
- ; rewrite a2eqq in eqq
- ; [| congruence]
- ; inversion eqq; simpl
- ; specialize (IHdf1 _ a1eqq)
- ; specialize (IHdf2 _ a2eqq)
- ; split; [| tauto]
- ; apply is_df_rec_prop_top in IHdf1
- ; apply is_df_rec_prop_top in IHdf2
- ; simpl in IHdf1, IHdf2
- ; rewrite IHdf1, IHdf2
- ; trivial)
-
- |
- case_eq (df_eval_tree σ df)
- ; [intros adf aeqq | intros aeqq]
- ; rewrite aeqq in eqq
- ; [| congruence]
- ; inversion eqq; simpl
- ; specialize (IHdf _ aeqq)
- ; split; [| tauto]
- ; apply is_df_rec_prop_top in IHdf
- ; simpl in IHdf
- ; rewrite IHdf
- ; trivial
- ].
-
- - Case "Number"%string.
- inversion eqq; subst.
- simpl; tauto.
- - Case "Constant"%string.
- inversion eqq; subst.
- simpl; tauto.
- - Case "DVector"%string.
- match_option_in eqq.
- invcs eqq.
- simpl.
- specialize (vectoro_to_ovector_forall_some_f eqq0)
- ; simpl
- ; clear eqq0; intros eqq0.
- split.
- + apply vectoro_to_ovector_forall_some_b_strong; intros i.
- specialize (H _ _ (eqq0 i)).
- apply is_df_rec_prop_top in H.
- simpl in *.
- rewrite vmap_nth; trivial.
- + apply vforall_forall; eauto.
- - Case "DMatrix"%string.
- match_option_in eqq.
- invcs eqq.
- simpl.
- unfold matrixo_to_omatrix in *.
- specialize (vectoro_to_ovector_forall_some_f eqq0)
- ; simpl
- ; clear eqq0; intros eqq0.
- split.
- + apply vectoro_to_ovector_forall_some_b_strong; intros i.
- apply vectoro_to_ovector_forall_some_b_strong; intros j.
- specialize (eqq0 i).
- specialize (vectoro_to_ovector_forall_some_f eqq0)
- ; simpl
- ; clear eqq0; intros eqq0.
- specialize (eqq0 j).
- specialize (H _ _ _ eqq0).
- apply is_df_rec_prop_top in H.
- simpl in *.
- repeat rewrite vmap_nth; trivial.
- + apply vforall_forall; intros.
- apply vforall_forall; intros.
- specialize (eqq0 i).
- specialize (vectoro_to_ovector_forall_some_f eqq0)
- ; simpl
- ; clear eqq0; intros eqq0.
- eauto.
- - Case "Var"%string.
- revert eqq.
- case_eq (vartlookup σ v) ; [| congruence].
- intros.
- inversion eqq; subst; simpl.
- rewrite H; tauto.
- - Case "VectorApply"%string.
- case_eq (df_eval_tree σ df2)
- ; [intros adf2 a2eqq | intros a2eqq]
- ; rewrite a2eqq in eqq
- ; [| congruence].
- specialize (IHdf2 _ a2eqq).
- match_option_in eqq.
- invcs eqq.
- simpl.
- split; trivial.
- apply is_df_rec_prop_top in IHdf2.
- simpl in IHdf2.
- rewrite IHdf2; trivial.
- - Case "MatrixApply"%string.
- case_eq (df_eval_tree σ df2)
- ; [intros adf2 a2eqq | intros a2eqq]
- ; rewrite a2eqq in eqq
- ; [| congruence].
- specialize (IHdf2 _ a2eqq).
- match_option_in eqq.
- invcs eqq.
- simpl.
- split; trivial.
- apply is_df_rec_prop_top in IHdf2.
- simpl in IHdf2.
- rewrite IHdf2; trivial.
- - Case "VLossfun"%string.
- case_eq (df_eval_tree σ df2)
- ; [intros adf2 a2eqq | intros a2eqq]
- ; rewrite a2eqq in eqq
- ; [| congruence].
- specialize (IHdf2 _ a2eqq).
- match_option_in eqq.
- invcs eqq.
- simpl.
- split; trivial.
- apply is_df_rec_prop_top in IHdf2.
- simpl in IHdf2.
- rewrite IHdf2, eqq0; trivial.
- - Case "MLossfun"%string.
- case_eq (df_eval_tree σ df2)
- ; [intros adf2 a2eqq | intros a2eqq]
- ; rewrite a2eqq in eqq
- ; [| congruence].
- specialize (IHdf2 _ a2eqq).
- match_option_in eqq.
- invcs eqq.
- simpl.
- split; trivial.
- apply is_df_rec_prop_top in IHdf2.
- simpl in IHdf2.
- rewrite IHdf2, eqq0; trivial.
- Qed.
-
-
- Lemma df_eval_tree_deriv_correct {T} {σ:df_env} {df:DefinedFunction EvalAnn T} :
- is_df_evalann_correct σ df ->
- forall (xv:var_type),
- (* let xv := (v, DTfloat) in *)
- df_eval_tree_deriv σ df xv = df_eval_deriv σ df xv.
- Proof.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case;
- intro iscor; destruct iscor;
- simpl; intros; trivial; unfold is_df_evalann_correct in *
- ; try solve
- [
- rewrite IHdf
- ; trivial
- |
- assert (is_df_evalann_correct_top σ df)
- ; [ apply is_df_rec_prop_top; trivial |
- unfold is_df_evalann_correct_top in H1
- ; rewrite H1
- ; rewrite IHdf
- ; trivial
- ]
- |
- rewrite IHdf1;
- [ rewrite IHdf2
- ; trivial
- ; tauto
- |
- tauto
- ]
- |
- rewrite IHdf1;
- [ assert (is_df_evalann_correct_top σ df2);
- [ apply is_df_rec_prop_top; trivial
- | unfold is_df_evalann_correct_top in H1;
- rewrite H1; trivial]
- | tauto]
- |
- destruct H0; rewrite IHdf1;
- [rewrite IHdf2;
- [assert (is_df_evalann_correct_top σ df1);
- [apply is_df_rec_prop_top; trivial
- | assert (is_df_evalann_correct_top σ df2);
- [ apply is_df_rec_prop_top; trivial
- |
- unfold is_df_evalann_correct_top in H2;
- unfold is_df_evalann_correct_top in H3;
- rewrite H2; rewrite H3; trivial ]]
- |
- tauto]
- |
- tauto]
- ].
- - Case "DVector"%string.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality.
- intro.
- apply H.
- destruct H0.
- rewrite vforall_forall in H1.
- eauto.
- - Case "DMatrix"%string.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality.
- intro.
- apply FunctionalExtensionality.functional_extensionality.
- intros.
- apply H.
- destruct H0.
- rewrite vforall_forall in H1.
- specialize (H1 x0).
- rewrite vforall_forall in H1.
- eauto.
- Qed.
-
- Lemma df_eval_tree_backprop_deriv_correct {T} (σ gradenv:df_env) (df:DefinedFunction EvalAnn T) (grad : definition_function_types_interp T) :
- is_df_evalann_correct σ df ->
- df_eval_tree_backprop_deriv σ df gradenv grad = df_eval_backprop_deriv σ df gradenv grad.
- Proof.
- revert gradenv grad.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case;
- intros; simpl;trivial
- ; try solve [
- destruct H;
- assert (is_df_evalann_correct σ df);
- [ unfold is_df_evalann_correct; trivial
- | apply is_df_rec_prop_top in H0;
- rewrite IHdf;
- [ unfold is_df_evalann_correct_top in H0;
- rewrite H0; trivial
- | trivial]]
- |
- destruct H;
- apply IHdf;
- unfold is_df_evalann_correct; trivial
- |
- destruct H; destruct H0; rewrite IHdf1;
- [ case_eq (df_eval_backprop_deriv σ df1 gradenv grad); [|congruence];
- intros;
- rewrite IHdf2; trivial
- | unfold is_df_evalann_correct; trivial]
-
- |
- destruct H;
- assert (is_df_evalann_correct σ df2);
- [ unfold is_df_evalann_correct; trivial
- | apply is_df_rec_prop_top in H0;
- unfold is_df_evalann_correct_top in H0;
- rewrite H0;
- match_destr;
- rewrite IHdf1; trivial]
- |
- destruct H; destruct H0; assert (is_df_evalann_correct σ df1);
- [ unfold is_df_evalann_correct; trivial
- | assert (is_df_evalann_correct σ df2);
- [ unfold is_df_evalann_correct; trivial
- | rewrite IHdf1; trivial;
- apply is_df_rec_prop_top in H0;
- apply is_df_rec_prop_top in H1;
- unfold is_df_evalann_correct_top in H0;
- unfold is_df_evalann_correct_top in H1;
- rewrite H0; rewrite H1;
- match_destr;
- rewrite IHdf2; trivial]]
- ].
- - Case "DVector"%string.
- destruct H0.
- rewrite vforall_forall in H1.
- unfold two_vector_env_iter_alt.
- f_equal; apply FunctionalExtensionality.functional_extensionality; intros
- ; apply FunctionalExtensionality.functional_extensionality ; intros.
- apply H; unfold is_df_evalann_correct; apply H1.
- - Case "DMatrix"%string.
- destruct H0.
- rewrite vforall_forall in H1.
- unfold two_matrix_env_iter_alt.
- f_equal; apply FunctionalExtensionality.functional_extensionality; intros
- ; apply FunctionalExtensionality.functional_extensionality; intros.
- f_equal; apply FunctionalExtensionality.functional_extensionality; intros
- ; apply FunctionalExtensionality.functional_extensionality; intros.
- apply H; unfold is_df_evalann_correct.
- specialize (H1 x0).
- rewrite vforall_forall in H1; apply H1.
- - Case "MatrixVectorAdd"%string.
- destruct H.
- destruct H0.
- assert (is_df_evalann_correct σ df1).
- unfold is_df_evalann_correct; trivial.
- assert (is_df_evalann_correct σ df2).
- unfold is_df_evalann_correct; trivial.
- rewrite IHdf1; trivial.
- match_destr.
- assert
- (list_env_iter
- (fun (i : {m' : nat | (m' < n)%nat}) (env : df_env) =>
- df_eval_tree_backprop_deriv σ df2 env (transpose grad i))
- (Some d) (bounded_seq0 n) =
- list_env_iter
- (fun (i : {m' : nat | (m' < n)%nat}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d)
- (bounded_seq0 n)).
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite IHdf2; trivial.
- rewrite H4; trivial.
- Qed.
-
- Lemma df_eval_ignores_ann {Ann T} {σ:df_env}
- (df:DefinedFunction Ann T) :
- df_eval σ df = df_eval σ (df_strip_annotations df).
- Proof.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; simpl; trivial
- ; try solve [
- rewrite IHdf; trivial
- |
- rewrite IHdf1;
- case_eq (df_eval σ (df_strip_annotations df1)); [|congruence];
- intros; rewrite IHdf2; trivial
- |
- rewrite IHdf1; trivial;
- case_eq (df_eval σ (df_strip_annotations df2)); [|congruence];
- intros; f_equal;
- apply FunctionalExtensionality.functional_extensionality; intros;
- f_equal; rewrite df_strip_annotations_id; trivial
- |
- rewrite IHdf1; trivial;
- case_eq (df_eval σ (df_strip_annotations df2)); [|congruence];
- intros; f_equal;
- rewrite df_strip_annotations_id; trivial
- ].
-
- - Case "DVector"%string.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite H.
- rewrite vmap_nth; trivial.
- - Case "DMatrix"%string.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H x0).
- rewrite H.
- rewrite vmap_nth.
- rewrite vmap_nth; trivial.
- Qed.
-
- Lemma df_eval_ignores_ann2 {Ann1 Ann2 T} {σ:df_env}
- (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) :
- df_eq_upto_annotations df1 df2 ->
- df_eval σ df1 = df_eval σ df2.
- Proof.
- assert (df_eval σ df1 = df_eval σ (df_strip_annotations df1)) by apply df_eval_ignores_ann.
- assert (df_eval σ df2 = df_eval σ (df_strip_annotations df2)) by apply df_eval_ignores_ann.
- congruence.
- Qed.
-
- Lemma df_eval_deriv_ignores_ann {Ann T} {σ:df_env}
- (df:DefinedFunction Ann T) :
- forall (xv:var_type),
- df_eval_deriv σ df xv = df_eval_deriv σ (df_strip_annotations df) xv.
- Proof.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case; simpl; trivial
- ; try solve
- [
- intro; rewrite IHdf1;
- case_eq (df_eval_deriv σ (df_strip_annotations df1) xv); [|congruence];
- intros;
- rewrite IHdf2; trivial
- |
- intro; rewrite df_eval_ignores_ann;
- case_eq (df_eval σ (df_strip_annotations df1)); [|congruence];
- intros; rewrite IHdf1; intros;
- case_eq (df_eval_deriv σ (df_strip_annotations df1) xv); [|congruence];
- intros; rewrite df_eval_ignores_ann;
- case_eq (df_eval σ (df_strip_annotations df2)); [|congruence];
- intros; rewrite IHdf2; trivial
- |
- intros; rewrite df_eval_ignores_ann;
- case_eq (df_eval σ (df_strip_annotations df)); [|congruence];
- intros; rewrite IHdf; intros;
- case_eq (df_eval_deriv σ (df_strip_annotations df) xv); [|congruence];
- trivial
- |
- intros; rewrite IHdf; trivial
- |
- intro; rewrite df_eval_ignores_ann;
- case_eq (df_eval σ (df_strip_annotations df2)); [|congruence];
- intros; rewrite IHdf1; intros;
- rewrite df_strip_annotations_id; trivial
- ].
-
- - Case "DVector"%string.
- intros.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite H.
- rewrite vmap_nth; trivial.
- - Case "DMatrix"%string.
- intros.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H x0).
- rewrite H.
- rewrite vmap_nth.
- rewrite vmap_nth; trivial.
- - Case "Max"%string.
- intro; rewrite df_eval_ignores_ann.
- case_eq (df_eval σ (df_strip_annotations df1)); [|congruence].
- intros.
- rewrite df_eval_ignores_ann.
- case_eq (df_eval σ (df_strip_annotations df2)); [|congruence].
- intros.
- rewrite IHdf1.
- rewrite IHdf2; trivial.
- Qed.
-
- Lemma df_eval_deriv_ignores_ann2 {Ann1 Ann2 T} {σ:df_env}
- (df1:DefinedFunction Ann1 T) (df2:DefinedFunction Ann2 T) :
- forall (xv:var_type),
- df_eq_upto_annotations df1 df2 ->
- df_eval_deriv σ df1 xv = df_eval_deriv σ df2 xv.
- Proof.
- intro.
- assert (df_eval_deriv σ df1 xv = df_eval_deriv σ (df_strip_annotations df1) xv) by apply df_eval_deriv_ignores_ann.
- assert (df_eval_deriv σ df2 xv = df_eval_deriv σ (df_strip_annotations df2) xv) by apply df_eval_deriv_ignores_ann.
- congruence.
- Qed.
-
- Lemma is_scalar_function_scalar {Ann} {T} (df:DefinedFunction Ann T) :
- is_scalar_function df -> is_scalar_df_type T.
- Proof.
- induction df; simpl; trivial.
- Qed.
-
-
-
- Definition definition_function_types_map_base (f:Type->Type) (dft:definition_function_types): Type
- := match dft with
- | DTfloat => f float
- | DTVector n => Vector (f float) n
- | DTMatrix m n => Matrix (f float) m n
- end.
-
- Definition definition_function_types_subgradient (dft:definition_function_types)
- := definition_function_types_map_base (fun t => list (list t)) dft.
-
-
- Definition df_eval_gradient {T} σ (df:DefinedFunction UnitAnn T) (lv:list var_type) : option (list (definition_function_types_interp T))
- := listo_to_olist (map (df_eval_deriv σ df) lv).
-
- Definition combine_prod (l1 l2 : list (list float)) : list (list (float * float))
- := let l12 := list_prod l1 l2
- in map (fun '(x,y) => combine x y) l12.
-(*
- Fixpoint df_eval_subgradient {dft:definition_function_types} (σ:df_env) (df:DefinedFunction dft) (lv:list SubVar) : option (definition_function_types_subgradient dft)
- := (match df with
- | Number _ => Some ((map (fun _ => 0) lv) :: nil)
- | DVector n v => vectoro_to_ovector (vmap (fun x => df_eval_subgradient σ x lv) v)
- | DMatrix n m df => matrixo_to_omatrix (vmap (fun x => vmap (fun y => df_eval_subgradient σ y lv) x) df)
- | Var x => Some ((map (fun v => if x == v then 1 else 0) lv) :: nil)
- | Plus l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (map (map (fun '(x, y) => x+y)) (combine_prod ld rd))
- | _, _ => None
- end
- | Minus l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (map (map (fun '(x, y) => x-y)) (combine_prod ld rd))
- | _, _ => None
- end
- | Times l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- Some (map (map (fun '(lp,rp) => lp*re + le*rp)) (combine_prod ld rd))
- | _, _, _, _ => None
- end
- | Divide l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- Some (map (map (fun '(lp,rp) => (lp*re - le*rp)/(re * re))) (combine_prod ld rd))
- | _, _, _, _ => None
- end
- | Square e =>
- match df_eval σ e, df_eval_subgradient σ e lv with
- | Some ee, Some ed => Some (map (map (fun pd => 2 * ee * pd)) ed)
- | _, _ => None
- end
- | Exp e =>
- match df_eval σ e, df_eval_subgradient σ e lv with
- | Some ee, Some ed => Some (map (map (fun pd => pd * Fexp ee)) ed)
- | _, _ => None
- end
- | Log e =>
- match df_eval σ e, df_eval_subgradient σ e lv with
- | Some ee, Some ed => Some (map (map (fun pd => (pd / ee))) ed)
- | _, _ => None
- end
- | Abs e =>
- match df_eval σ e, df_eval_subgradient σ e lv with
- | Some ee, Some ed =>
- if Feq ee 0 then Some (ed ++ (map (map (fun ep => -ep)) ed))
- else Some (map (map (fun ed => (ed * (sign ee)))) ed)
- | _, _ => None
- end
- | Sign e =>
- match df_eval σ e with
- | Some ee => Some ((map (fun _ => 0) lv) :: nil )
- | _ => None
- end
- | PSign e =>
- match df_eval σ e with
- | Some ee => Some ((map (fun _ => 0) lv) :: nil )
- | _ => None
- end
- | Max l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- if Feq le re then Some (ld ++ rd)
- else if le > re then Some ld
- else Some rd
- | _, _, _, _ => None
- end
- | VectorElem n l i =>
- match (df_eval_subgradient σ l lv) with
- | Some l' => Some (l' i)
- | _ => None
- end
- | MatrixElem m n l i j =>
- match (df_eval_subgradient σ l lv) with
- | Some l' => Some (l' i j)
- | _ => None
- end
- | VectorSum n l =>
- match df_eval_subgradient σ l lv with
- | Some l' =>
- Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp))
- (combine_prod a b))
- ((map (fun _ => 0) lv)::nil) l')
- | _ => None
- end
- | VectorDot n l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp))
- (combine_prod a b))
- ((map (fun _ => 0) lv)::nil)
- (fun i => map (map (fun '(lp,rp) => lp*(re i) + (le i)*rp))
- (combine_prod (ld i) (rd i))))
- | _, _, _, _ => None
- end
- | VectorScalMult n x r =>
- match df_eval σ x, df_eval_subgradient σ x lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some xe, Some xd, Some re, Some rd =>
- Some (fun j => map (map (fun '(xp,rp) => xe * rp + xp * (re j))) (combine_prod xd (rd j)))
- | _, _, _, _ => None
- end
- | MatrixScalMult n m x r =>
- match df_eval σ x, df_eval_subgradient σ x lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some xe, Some xd, Some re, Some rd =>
- Some (fun i j => map (map (fun '(xp,rp) => xe * rp + xp * (re i j))) (combine_prod xd (rd i j)))
-
- | _, _, _, _ => None
- end
- | MatrixVectorMult n m l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- Some (fun i =>
- (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp))
- (combine_prod a b))
- ((map (fun _ => 0) lv)::nil)
- (fun j => map (map (fun '(lp,rp) => lp*(re j) + (le i j)*rp))
- (combine_prod (ld i j) (rd j)))))
- | _, _, _, _ => None
- end
- | MatrixMult n m p l r =>
- match df_eval σ l, df_eval_subgradient σ l lv, df_eval σ r, df_eval_subgradient σ r lv with
- | Some le, Some ld, Some re, Some rd =>
- Some (fun i k =>
- (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp))
- (combine_prod a b))
- ((map (fun _ => 0) lv)::nil)
- (fun j => map (map (fun '(lp,rp) => lp*(re j k) + (le i j)*rp))
- (combine_prod (ld i j) (rd j k)))))
- | _, _, _, _ => None
- end
- | VectorPlus n l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (fun i => (map (map (fun '(x, y) => x+y)) (combine_prod (ld i) (rd i))))
- | _, _ => None
- end
- | VectorMinus n l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (fun i => (map (map (fun '(x, y) => x-y)) (combine_prod (ld i) (rd i))))
- | _, _ => None
- end
- | MatrixPlus n m l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (fun i j => (map (map (fun '(x, y) => x+y)) (combine_prod (ld i j) (rd i j))))
- | _, _ => None
- end
- | MatrixMinus n m l r =>
- match df_eval_subgradient σ l lv, df_eval_subgradient σ r lv with
- | Some ld, Some rd => Some (fun i j => (map (map (fun '(x, y) => x-y)) (combine_prod (ld i j) (rd i j))))
- | _, _ => None
- end
- | VectorApply n x s r =>
- match df_eval σ r, df_eval_subgradient σ r lv with
- | Some re, Some rd =>
- vectoro_to_ovector
- (fun i => match df_eval_subgradient (cons (x, re i) σ) s lv with
- | Some sd =>
- Some (map (map (fun '(x, y) => x*y)) (combine_prod (rd i) sd))
- | _ => None
- end)
- | _, _ => None
- end
- | Lossfun n v1 v2 s l r =>
- match df_eval σ l, df_eval_subgradient σ l lv with
- | Some le, Some ld =>
- match (vectoro_to_ovector
- (fun i => match df_eval_subgradient (cons (v1, (le i)) (cons (v2, r i) σ)) s lv with
- | Some sd => Some (map (map (fun '(x, y) => x*y)) (combine_prod (ld i) sd))
- | _ => None
- end)) with
- | Some vv => Some (vector_fold_right (fun a b => map (map (fun '(xp,rp) => xp + rp))
- (combine_prod a b))
- ((map (fun _ => 0) lv)::nil) vv)
- | _ => None
- end
- | _, _ => None
- end
- end).
-*)
- End deriv2.
-
- Definition dft_one (dft:definition_function_types) : definition_function_types_interp dft
- := match dft with
- | DTfloat => 1
- | DTVector n => fun _ => 1
- | DTMatrix m n => fun _ _ => 1
- end.
-
- Section scalar_ind.
-
- Fixpoint is_scalar_function_ind_gen {Ann}
- {P:forall {T}, DefinedFunction Ann T->Prop}
- (fnumber:forall ann x, P (Number ann x))
- (fconstant:forall (ann:Ann DTfloat) x, P (@Constant _ DTfloat ann x))
- (fvar:forall sv ann, P (@Var _ (sv,DTfloat) ann))
- (fplus:forall a l r, P l -> P r -> P (Plus a l r))
- (fminus:forall a l r, P l -> P r -> P (Minus a l r))
- (ftimes:forall a l r, P l -> P r -> P (Times a l r))
- (fdivide:forall a l r, P l -> P r -> P (Divide a l r))
- (fsquare:forall a e, P e -> P (Square a e))
- (fexp:forall a e, P e -> P (Exp a e))
- (flog:forall a e, P e -> P (Log a e))
- (fabs:forall a e, P e -> P (Abs a e))
- (fsign:forall a e, P e -> P (Sign a e))
- (fpsign:forall a e, P e -> P (PSign a e))
- (fmax:forall a l r, P l -> P r -> P (Max a l r))
- {T}
- (df:DefinedFunction Ann T) {struct df} : is_scalar_function df -> P df.
- Proof.
- induction df; simpl; intros isc; try tauto.
- - apply fnumber.
- - destruct t; simpl in isc; try tauto.
- apply fconstant.
- - destruct v.
- destruct d; simpl in isc; try tauto.
- apply fvar.
- - apply fplus.
- + apply IHdf1; tauto.
- + apply IHdf2; tauto.
- - apply fminus.
- + apply IHdf1; tauto.
- + apply IHdf2; tauto.
- - apply ftimes.
- + apply IHdf1; tauto.
- + apply IHdf2; tauto.
- - apply fdivide.
- + apply IHdf1; tauto.
- + apply IHdf2; tauto.
- - apply fsquare.
- + apply IHdf; tauto.
- - apply fexp.
- + apply IHdf; tauto.
- - apply flog.
- + apply IHdf; tauto.
- - apply fabs.
- + apply IHdf; tauto.
- - apply fsign.
- + apply IHdf; tauto.
- - apply fpsign.
- + apply IHdf; tauto.
- - apply fmax.
- + apply IHdf1; tauto.
- + apply IHdf2; tauto.
- Qed.
-
- Definition is_scalar_function_ind {Ann}
- {P:DefinedFunction Ann DTfloat->Prop}
- (fnumber:forall ann x, P (Number ann x))
- (fconstant:forall (ann:Ann DTfloat) x, P (@Constant _ DTfloat ann x))
- (fvar:forall sv ann, P (@Var _ (sv,DTfloat) ann))
- (fplus:forall a l r, P l -> P r -> P (Plus a l r))
- (fminus:forall a l r, P l -> P r -> P (Minus a l r))
- (ftimes:forall a l r, P l -> P r -> P (Times a l r))
- (fdivide:forall a l r, P l -> P r -> P (Divide a l r))
- (fsquare:forall a e, P e -> P (Square a e))
- (fexp:forall a e, P e -> P (Exp a e))
- (flog:forall a e, P e -> P (Log a e))
- (fabs:forall a e, P e -> P (Abs a e))
- (fsign:forall a e, P e -> P (Sign a e))
- (fpsign:forall a e, P e -> P (PSign a e))
- (fmax:forall a l r, P l -> P r -> P (Max a l r))
- (df:DefinedFunction Ann DTfloat) : is_scalar_function df -> P df.
- Proof.
- apply (@is_scalar_function_ind_gen _ (fun t => match t with
- | DTfloat => fun df => P df
- | _ => fun _ => False
- end)); trivial.
- Qed.
-
- Definition vartlookup_eq (l1 l2:df_env) : Prop := forall a, vartlookup l1 a = vartlookup l2 a.
-
- Global Instance vartlookup_eq_equiv : Equivalence vartlookup_eq.
- Proof.
- unfold vartlookup_eq.
- constructor; red.
- - intros; reflexivity.
- - intros; eauto.
- - intro; etransitivity; eauto.
- Qed.
-
- End scalar_ind.
-
- Lemma lookup_update (xv : var_type) (gradenv : df_env)
- (val : definition_function_types_interp (snd xv)) :
- vartlookup (vart_update gradenv xv val) xv = Some val.
- Proof.
- induction gradenv; simpl.
- - destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence].
- refl_simpler; simpl; trivial.
- - destruct a; simpl.
- case_eq (@equiv_dec var_type _ _ _ xv x); simpl; intros.
- + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence].
- refl_simpler; simpl; trivial.
- + rewrite H; trivial.
- Qed.
-
- Lemma lookup_update_neq (xv1 xv2 : var_type) (gradenv : df_env)
- (val : definition_function_types_interp (snd xv1)) : xv1 <> xv2 ->
- vartlookup (vart_update gradenv xv1 val) xv2 = vartlookup gradenv xv2.
- Proof.
- intros neq.
- induction gradenv; simpl.
- - destruct (@equiv_dec var_type _ _ _ xv2 xv1); congruence.
- - destruct a; simpl.
- case_eq (@equiv_dec var_type _ _ _ xv1 x); simpl; intros.
- + destruct (@equiv_dec var_type _ _ _ xv2 xv1); [congruence | ].
- destruct (@equiv_dec var_type _ _ _ xv2 x); congruence.
- + destruct (@equiv_dec var_type _ _ _ xv2 x); congruence.
- Qed.
-
- Lemma lookup_update2 (xv : var_type) (gradenv : df_env)
- (val : definition_function_types_interp (snd xv)) :
- vartlookup ((mk_env_entry xv val) :: gradenv) xv = Some val.
- Proof.
- induction gradenv; simpl.
- - destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence].
- refl_simpler; simpl; trivial.
- - destruct a; simpl.
- case_eq (@equiv_dec var_type _ _ _ xv x); simpl; intros.
- + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence].
- refl_simpler; simpl; trivial.
- + destruct (@equiv_dec var_type _ _ _ xv xv); [| congruence].
- refl_simpler; simpl; trivial.
- Qed.
-
-
-Tactic Notation "DefinedFunction_scalar_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string].
-
- Lemma df_eval_backprop_deriv_preserves_lookup_not_none {Ann T} {env} {grad gradenv d} {df:DefinedFunction Ann T} :
- df_eval_backprop_deriv env df gradenv grad = Some d ->
- forall xv,
- vartlookup gradenv xv <> None ->
- vartlookup d xv <> None.
- Proof.
- simpl.
- revert grad gradenv d.
- DefinedFunction_cases (induction df) Case; simpl.
- - Case "Number"%string; intros; inversion H; subst; easy.
- - Case "Constant"%string; intros; inversion H; subst; easy.
- - Case "DVector"%string.
- intros grad.
- unfold two_vector_env_iter_alt.
- induction (bounded_seq0 n).
- simpl.
- intros.
- inversion H0; subst; trivial.
- simpl.
- intros gradenv d.
- case_eq (df_eval_backprop_deriv env (x a) gradenv (grad a)).
- intros.
- specialize (H a (grad a) gradenv d0).
- specialize (IHl d0 d).
- apply IHl; trivial.
- apply H; trivial.
- intros.
- assert (list_env_iter
- (fun (i : {n' : nat | (n' < n)%nat}) (env0 : df_env) =>
- df_eval_backprop_deriv env (x i) env0 (grad i)) None l = None)
- by apply list_env_iter_none.
- intros; rewrite H1 in H3; discriminate.
- - Case "DMatrix"%string.
- intros grad.
- unfold two_matrix_env_iter_alt.
- induction (bounded_seq0 n); simpl.
- { intros; inversion H0; subst; trivial. }
- intros gradenv d eqq.
- case_eq ((list_env_iter
- (fun (j : {m' : nat | (m' < m)%nat}) (env0 : df_env) =>
- df_eval_backprop_deriv env (x a j) env0 (grad a j)) (Some gradenv)
- (bounded_seq0 m)))
- ; [ intros dd ddeqq | intros ddeqq]
- ; rewrite ddeqq in eqq
- ; simpl in eqq
- ; [| destruct l; simpl; discriminate].
- specialize (IHl _ _ eqq).
- cut (forall xv : var_type, vartlookup gradenv xv <> None -> vartlookup dd xv <> None)
- ; [ eauto | ].
- clear d IHl eqq.
- revert gradenv dd ddeqq.
- induction (bounded_seq0 m); simpl
- ; intros gradenv dd ddeqq
- ; simpl in ddeqq.
- { inversion ddeqq; subst; trivial. }
- case_eq (df_eval_backprop_deriv env (x a a0) gradenv (grad a a0))
- ; [intros dd2 ddeqq2 | intros ddeqq2]
- ; rewrite ddeqq2 in ddeqq
- ; simpl in ddeqq
- ; [| destruct l0; simpl; discriminate].
- eauto.
- - Case "Var"%string.
- intros.
- destruct (vartlookup gradenv v) ; [|congruence].
- intros.
- inversion H.
- destruct (vart_dec v xv).
- + subst; rewrite lookup_update.
- discriminate.
- + rewrite lookup_update_neq; trivial.
- - Case "Plus"%string.
- intros grad gradenv.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d).
- specialize (IHdf2 grad d d0).
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "Minus"%string.
- intros grad gradenv.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d).
- specialize (IHdf2 (-grad) d d0).
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "Times"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv (d1 * grad)) ; [|congruence].
- intros d2.
- specialize (IHdf1 (d1 * grad) gradenv d2).
- specialize (IHdf2 (d0 * grad) d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "Divide"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv (grad / d1)) ; [|congruence].
- intros d2.
- specialize (IHdf1 (grad / d1) gradenv d2).
- specialize (IHdf2 (- d0 / (d1 * d1) * grad) d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "Square"%string; intros.
- destruct (df_eval env df) ; [|congruence].
- specialize (IHdf (2 * d0 * grad) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "Exp"%string; intros.
- destruct (df_eval env df) ; [|congruence].
- specialize (IHdf (grad * Fexp d0) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "Log"%string; intros.
- destruct (df_eval env df) ; [|congruence].
- specialize (IHdf (grad / d0) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "Abs"%string; intros.
- destruct (df_eval env df) ; [|congruence].
- specialize (IHdf (grad * sign d0) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "Sign"%string; intros.
- specialize (IHdf 0 gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "PSign"%string; intros.
- specialize (IHdf 0 gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "Max"%string; intros.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- destruct (d0 <= d1).
- specialize (IHdf2 grad gradenv d).
- apply IHdf2.
- apply H.
- trivial.
- specialize (IHdf1 grad gradenv d).
- apply IHdf1.
- apply H.
- trivial.
- - Case "VectorDot"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv
- (vmap (fun rv : float => rv * grad) d1)); [|congruence].
- intros.
- specialize (IHdf1 (vmap (fun rv : float => rv * grad) d1) gradenv d2).
- specialize (IHdf2 (vmap (fun lv : float => lv * grad) d0) d2 d).
- apply IHdf2.
- apply H0.
- apply IHdf1.
- apply H.
- apply H1.
- - Case "VectorSum"%string.
- intros.
- specialize (IHdf (ConstVector n grad) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "MatrixSum"%string.
- intros.
- specialize (IHdf (ConstMatrix m n grad) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "VectorElem"%string.
- intros.
- specialize (IHdf (fun k : {n' : nat | (n' < n)%nat} =>
- if equiv_dec (proj1_sig k) (proj1_sig i)
- then grad else 0) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "MatrixElem"%string.
- intros.
- specialize (IHdf (fun (k1 : {n' : nat | (n' < m)%nat})
- (k2 : {m' : nat | (m' < n)%nat}) =>
- if equiv_dec (proj1_sig k1) (proj1_sig i)
- then if equiv_dec (proj1_sig k2) (proj1_sig j) then grad else 0
- else 0) gradenv d).
- apply IHdf.
- apply H.
- apply H0.
- - Case "MatrixVectorMult"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv
- (fun (i : {n' : nat | (n' < m)%nat})
- (j : {m' : nat | (m' < n)%nat}) => grad i * d1 j)) ; [|congruence].
- intros d2.
- specialize (IHdf1 (fun (i : {n' : nat | (n' < m)%nat})
- (j : {m' : nat | (m' < n)%nat}) => grad i * d1 j) gradenv d2).
- specialize (IHdf2 (matrix_vector_mult
- (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < m)%nat}) => d0 j i) grad) d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "MatrixVectorAdd"%string.
- intros grad gradenv d.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad); [|congruence].
- intros d0 casedf1.
- specialize (IHdf1 _ _ _ casedf1).
- clear casedf1.
- revert gradenv d d0 IHdf1.
- induction (bounded_seq0 n).
- + simpl.
- intros.
- inversion H; subst.
- eauto.
- + intros gradenv d d0 d0eqq.
- simpl.
- case_eq (df_eval_backprop_deriv env df2 d0 (transpose grad a)); simpl
- ; [intros ? eqq1 | intros eqq1].
- * intros.
- { apply (IHl d0 _ d1); trivial.
- - eapply IHdf2; eauto.
- - eauto.
- }
- * destruct l; simpl; discriminate.
- - Case "MatrixMult"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv
- (matrix_mult grad
- (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < p)%nat}) => d1 j i)))
- ; [|congruence].
- intros d2.
- specialize (IHdf1 (matrix_mult grad
- (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < p)%nat}) => d1 j i))
- gradenv d2).
- specialize (IHdf2 (matrix_mult
- (fun (i : {n' : nat | (n' < p)%nat})
- (j : {m' : nat | (m' < m)%nat}) => d0 j i) grad) d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "VectorPlus"%string.
- intros grad gradenv d.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d0).
- specialize (IHdf2 grad d0 d).
- apply IHdf2.
- apply H0.
- apply IHdf1.
- apply H.
- apply H1.
- - Case "VectorMinus"%string.
- intros grad gradenv d.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d0).
- specialize (IHdf2 (fun i : {n' : nat | (n' < n)%nat} => - grad i) d0 d).
- apply IHdf2.
- apply H0.
- apply IHdf1.
- apply H.
- apply H1.
- - Case "MatrixPlus"%string.
- intros grad gradenv d.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d0).
- specialize (IHdf2 grad d0 d).
- apply IHdf2.
- apply H0.
- apply IHdf1.
- apply H.
- apply H1.
- - Case "MatrixMinus"%string.
- intros grad gradenv d.
- case_eq (df_eval_backprop_deriv env df1 gradenv grad) ; [|congruence].
- intros.
- specialize (IHdf1 grad gradenv d0).
- specialize (IHdf2 (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < m)%nat}) => - grad i j)
- d0 d).
- apply IHdf2.
- apply H0.
- apply IHdf1.
- apply H.
- apply H1.
- - Case "VectorScalMult"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv
- (vsum (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad j)))
- ; [|congruence].
- intros d2.
- specialize (IHdf1 (vsum (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad j))
- gradenv d2).
- specialize (IHdf2 (fun j : {n' : nat | (n' < n)%nat} => d0 * grad j)
- d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "MatrixScalMult"%string.
- intros grad gradenv d.
- destruct (df_eval env df1) ; [|congruence].
- destruct (df_eval env df2) ; [|congruence].
- case_eq (df_eval_backprop_deriv env df1 gradenv
- (msum
- (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < m)%nat}) =>
- d1 i j * grad i j)))
- ; [|congruence].
- intros d2.
- specialize (IHdf1 (msum
- (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < m)%nat}) =>
- d1 i j * grad i j))
- gradenv d2).
- specialize (IHdf2 (fun (i : {n' : nat | (n' < n)%nat})
- (j : {m' : nat | (m' < m)%nat}) => grad i j * d0)
- d2 d).
- intros.
- apply IHdf2; trivial.
- apply IHdf1; trivial.
- - Case "VectorApply"%string.
- intros grad gradenv d.
- destruct (df_eval env df2) ; [|congruence].
- simpl in *.
- match_destr; simpl; eauto.
- - Case "MatrixApply"%string.
- intros grad gradenv d.
- destruct (df_eval env df2) ; [|congruence].
- match_destr; simpl; eauto.
- - Case "VLossfun"%string.
- intros grad gradenv d.
- destruct (df_eval env df2) ; [|congruence].
- match_destr; simpl; eauto.
- - Case "MLossfun"%string.
- intros grad gradenv d.
- destruct (df_eval env df2) ; [|congruence].
- match_destr; simpl; eauto.
- Qed.
-
- Definition df_eval_deriv_gen_top {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) :
- option (lifted_type (definition_function_types_interp T) (snd v)) :=
- match (snd v) as vt return option (lifted_type (definition_function_types_interp T) vt) with
- | DTfloat => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTfloat) 1)::nil)
- | DTVector n =>
- vectoro_to_ovector
- (fun i => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTVector n) (UnitVector n i))::nil))
- | DTMatrix n m =>
- matrixo_to_omatrix
- (fun i j => df_eval_deriv_genvar σ df ((mk_env_entry (fst v, DTMatrix n m) (UnitMatrix n m i j))::nil))
- end.
-
- Program Definition subvar (x : var_type) (grad_env:df_env) :=
- (match snd x as y return snd x = y ->
- definition_function_types_interp y ->
- definition_function_types_interp y with
- | DTfloat => fun pf grad => match vartlookup grad_env x with
- | Some val => ((coerce _ val):float) - grad
- | _ => Fopp grad
- end
- | DTVector n => fun pf grad => match vartlookup grad_env x with
- | Some val => fun i => (((coerce _ val):Vector float n) i) - (grad i)
- | _ => vmap Fopp grad
- end
- | DTMatrix m n => fun pf grad => match vartlookup grad_env x with
- | Some val => fun i j => (((coerce _ val):Matrix float m n) i j) - (grad i j)
- | _ => vmap (vmap Fopp) grad
- end
- end) (eq_refl _).
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
- Next Obligation.
- rewrite pf; reflexivity.
- Defined.
-
- Definition df_eval_backprop_delta {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) (grad_env:df_env) (grad: definition_function_types_interp T) :
- option (definition_function_types_interp (snd v)) :=
- match vartlookup grad_env v with
- | Some old =>
- lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env grad)
- | None => None
- end.
-
-(*
- Program Definition df_eval_backward_gen_top {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) (grad_env:df_env) :
- option (lifted_type (definition_function_types_interp (snd v)) T) :=
- (match T as vt return T = vt -> option (lifted_type (definition_function_types_interp (snd v)) vt) with
- | DTfloat => fun pf => df_eval_backprop_delta σ df v grad_env (coerce _ 1)
- | DTVector n => fun pf =>
- vectoro_to_ovector
- (fun i => df_eval_backprop_delta σ df v grad_env (coerce _ (UnitVector n i)))
- | DTMatrix m n => fun pf =>
- matrixo_to_omatrix
- (fun i j => df_eval_backprop_delta σ df v grad_env (coerce _ (UnitMatrix m n i j)))
- end) (eq_refl _).
- *)
-
- Program Definition df_eval_backward_gen_top {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v: var_type) (grad_env:df_env) :
- option (lifted_type (definition_function_types_interp (snd v)) T) :=
- match vartlookup grad_env v with
- | Some old =>
- (match T as vt return T = vt -> option (lifted_type (definition_function_types_interp (snd v)) vt) with
- | DTfloat => fun pf => (lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ 1)))
- | DTVector n => fun pf =>
- vectoro_to_ovector
- (fun i => lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ (UnitVector n i))))
- | DTMatrix m n => fun pf =>
- matrixo_to_omatrix
- (fun i j => lift (fun e => subvar v e old) (df_eval_backprop_deriv σ df grad_env (coerce _ (UnitMatrix m n i j))))
- end) (eq_refl _)
- | None => None
- end.
-
- Definition transpose_lifted_type {T1 T2} :
- lifted_type (definition_function_types_interp T1) T2 ->
- lifted_type (definition_function_types_interp T2) T1
- := match T1, T2 with
- | DTfloat, _ => fun inp => inp
- | _, DTfloat => fun inp => inp
- | DTVector n1, DTVector n2 => fun inp => fun i j => inp j i
- | DTMatrix m1 n1, DTMatrix m2 n2 => fun inp => fun i j p q => inp p q i j
- | DTVector n1, DTMatrix m2 n2 => fun inp => fun i p q => inp p q i
- | DTMatrix m1 n1, DTVector n2 => fun inp => fun i j p => inp p i j
- end.
- Section deriv_deriv.
- End deriv_deriv.
-
- Section max_derived.
- Definition MaxDerived (a b : DefinedFunction UnitAnn DTfloat) :=
- Divide tt (Plus tt (Plus tt (Abs tt (Minus tt b a)) b) a) (Number tt 2).
-
- Delimit Scope df_scope with df.
-
- Notation "x + y" := (Plus x y) (only printing) : df_scope.
- Notation "x - y" := (Minus x y) (only printing) : df_scope.
- Notation "x / y" := (Divide x y) (only printing) : df_scope.
- Notation "x * y" := (Times x y) (only printing) : df_scope.
- Notation "x" := (Number x) (only printing, at level 0) : df_scope.
- Notation "x" := (Var x) (only printing, at level 0) : df_scope.
- Notation "'|' x '|'" := (Abs x) (only printing, at level 0) : df_scope.
-
- End max_derived.
-
- Section fv.
-
- Fixpoint df_free_variables {Ann} {T} (f : DefinedFunction Ann T) : list var_type
- := match f with
- | Number _ x => nil
- | DVector n _ x => vlconcat_map df_free_variables x
- | Constant t _ x => nil
- | DMatrix n m _ x => vlconcat_map (fun a => vlconcat_map df_free_variables a) x
- | Var v _ => v::nil
- | Plus _ l r => (df_free_variables l) ++ (df_free_variables r)
- | Minus _ l r => (df_free_variables l) ++ (df_free_variables r)
- | Times _ l r => (df_free_variables l) ++ (df_free_variables r)
- | Divide _ l r => (df_free_variables l) ++ (df_free_variables r)
- | Max _ l r => (df_free_variables l) ++ (df_free_variables r)
- | Abs _ e => df_free_variables e
- | Sign _ e => df_free_variables e
- | PSign _ e => df_free_variables e
- | Log _ e => df_free_variables e
- | Square _ e => df_free_variables e
- | Exp _ e => df_free_variables e
- | VectorElem n _ l i => df_free_variables l
- | MatrixElem m n _ l i j => df_free_variables l
- | VectorDot n _ l r => (df_free_variables l) ++ (df_free_variables r)
- | VectorSum n _ l => df_free_variables l
- | MatrixSum n m _ l => df_free_variables l
- | VectorScalMult n _ x r => (df_free_variables x) ++ (df_free_variables r)
- | MatrixScalMult n m _ x r => (df_free_variables x) ++ (df_free_variables r)
- | MatrixVectorMult n m _ l r => (df_free_variables l) ++ (df_free_variables r)
- | MatrixVectorAdd n m _ l r => (df_free_variables l) ++ (df_free_variables r)
- | MatrixMult n m p _ l r => (df_free_variables l) ++ (df_free_variables r)
- | VectorPlus n _ l r => (df_free_variables l) ++ (df_free_variables r)
- | VectorMinus n _ l r => (df_free_variables l) ++ (df_free_variables r)
- | MatrixPlus n m _ l r => (df_free_variables l) ++ (df_free_variables r)
- | MatrixMinus n m _ l r => (df_free_variables l) ++ (df_free_variables r)
- | VectorApply n _ x s l => (remove_all (x,DTfloat) (df_free_variables s))
- ++ (df_free_variables l)
- | MatrixApply n m _ x s l => (remove_all (x,DTfloat) (df_free_variables s))
- ++ (df_free_variables l)
- | VLossfun n _ v1 v2 s l r => (remove_all (v1,DTfloat) (remove_all (v2,DTfloat) (df_free_variables s)))
- ++ (df_free_variables l)
- | MLossfun n m _ v1 v2 s l r => (remove_all (v1,DTfloat) (remove_all (v2,DTfloat) (df_free_variables s)))
- ++ (df_free_variables l)
- end.
-
- Definition df_closed {Ann} {T} (f: DefinedFunction Ann T) : Prop
- := match df_free_variables f with
- | nil => True
- | _ => False
- end.
-
- Lemma df_closed_nil {T} (f: DefinedFunction UnitAnn T) : df_closed f -> df_free_variables f = nil.
- Proof.
- unfold df_closed.
- destruct (df_free_variables f); tauto.
- Qed.
-
- Definition df_closed_over {Ann} {T} (f : DefinedFunction Ann T) (vl : list var_type) : Prop
- := incl (df_free_variables f) vl.
-
- Fixpoint fully_closed_over {Ann} {T} (df : DefinedFunction Ann T) (vl : list var_type) : Prop
- :=
- match df with
- | Number _ x => True
- | DVector n _ x => vforall (fun f => fully_closed_over f vl) x
- | Constant t _ x => True
- | DMatrix n m _ x => vforall (fun row =>
- (vforall (fun f => fully_closed_over f vl) row)) x
- | Var v _ => In v vl
- | Plus _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | Minus _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | Times _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | Divide _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | Max _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | Abs _ e => fully_closed_over e vl
- | Sign _ e => fully_closed_over e vl
- | PSign _ e => fully_closed_over e vl
- | Log _ e => fully_closed_over e vl
- | Square _ e => fully_closed_over e vl
- | Exp _ e => fully_closed_over e vl
- | VectorElem n _ l i => fully_closed_over l vl
- | MatrixElem m n _ l i j => fully_closed_over l vl
- | VectorDot n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | VectorSum n _ l => fully_closed_over l vl
- | MatrixSum n m _ l => fully_closed_over l vl
- | VectorScalMult n _ x r => (fully_closed_over x vl) /\ (fully_closed_over r vl)
- | MatrixScalMult n m _ x r => (fully_closed_over x vl) /\ (fully_closed_over r vl)
- | MatrixVectorMult n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | MatrixVectorAdd n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | MatrixMult n m p _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | VectorPlus n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | VectorMinus n _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | MatrixPlus n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | MatrixMinus n m _ l r => (fully_closed_over l vl) /\ (fully_closed_over r vl)
- | VectorApply n _ x s l => (fully_closed_over s ((x,DTfloat)::nil)) /\
- (fully_closed_over l vl)
- | MatrixApply n m _ x s l => (fully_closed_over s ((x,DTfloat)::nil)) /\
- (fully_closed_over l vl)
- | VLossfun n _ v1 v2 s l r => (fully_closed_over s ((v1,DTfloat)::(v2,DTfloat)::nil))
- /\ (fully_closed_over l vl)
- | MLossfun n m _ v1 v2 s l r => (fully_closed_over s ((v1,DTfloat)::(v2,DTfloat)::nil))
- /\ (fully_closed_over l vl)
- end.
-
- Definition In_compat_map (f : list var_type -> list var_type) : Prop :=
- forall (v : var_type) (vl : list var_type),
- In v vl -> In v (f vl).
-
- Definition map_tl (f : list var_type -> list var_type) (vl : list var_type) :=
- match vl with
- | a :: vl1 => a :: f vl1
- | _ => f vl
- end.
-
- Lemma In_compat_map_tl (f : list var_type -> list var_type) :
- In_compat_map f -> In_compat_map (map_tl f).
- Proof.
- unfold In_compat_map; intros.
- destruct vl.
- + now simpl.
- + simpl in *.
- destruct H0.
- * now left.
- * right; now apply H.
- Qed.
-
- Lemma fully_closed_over_map {T} (df : DefinedFunction UnitAnn T) (vl : list var_type) (f : list var_type -> list var_type) :
- In_compat_map f -> fully_closed_over df vl -> fully_closed_over df (f vl).
- Proof.
- revert f; revert vl.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros; try solve [
- trivial
- |
- apply IHdf; trivial
- |
- split; destruct H0;
- [apply IHdf1; trivial
- | apply IHdf2; trivial]
- ].
- - Case "DVector"%string.
- apply vforall_forall; intros.
- apply H; trivial.
- now rewrite vforall_forall in H1.
- - Case "DMatrix"%string.
- apply vforall_forall; intros.
- apply vforall_forall; intros.
- apply H; trivial.
- rewrite vforall_forall in H1.
- specialize (H1 i).
- now rewrite vforall_forall in H1.
- - now apply H.
- - Case "VectorApply"%string.
- split; destruct H0; trivial.
- now apply IHdf2.
- - Case "MatrixApply"%string.
- split; destruct H0; trivial.
- now apply IHdf2.
- - Case "VLossfun"%string.
- split; destruct H0; trivial.
- now apply IHdf2.
- - Case "MLossfun"%string.
- split; destruct H0; trivial.
- now apply IHdf2.
- Qed.
-
- (*
- Lemma closed_is_fully_closed {Ann} {T} (df : DefinedFunction Ann T) (vl : list var_type) :
- df_closed_over df vl <-> fully_closed_over df vl.
-*)
-
-
-(*
- Lemma df_subst_nfree {T} (e: DefinedFunction T) (v:SubVar) (e':DefinedFunction DTfloat) :
- ~ In v (df_free_variables e) ->
- df_subst e v e' = e.
- Proof.
- DefinedFunction_cases (induction e) Case; simpl; trivial; intros nin
- ; try solve [try rewrite in_app_iff in nin
- ; intuition congruence].
- - Case "DVector"%string.
- f_equal.
- apply functional_extensionality.
- intros x0.
- apply H.
- intros inn.
- apply nin.
- unfold vlconcat_map, vlconcat.
- apply concat_In.
- exists ((df_free_variables (x x0))).
- split; trivial.
- apply vector_to_list_In.
-
- - Case "DMatrix"%string.
-
- - Case "Var"%string.
- destruct (var_dec v0 v); intuition.
- Qed.
-
- Lemma df_eval_complete' {T} (σ:df_env) (f:DefinedFunction T) :
- incl (df_free_variables f) (domain σ) -> {v | df_eval σ f = Some v}.
- Proof.
- induction f; simpl; intros inc
- ; try solve [rewrite <- incl_app_iff in inc
- ; intuition
- ; destruct X as [v1 ev1]
- ; destruct X0 as [v2 ev2]
- ; rewrite ev1; rewrite ev2
- ; eauto
- | intuition
- ; destruct X as [v1 ev1]
- ; rewrite ev1
- ; eauto].
- - eauto.
- - apply in_dom_lookup_strong.
- specialize (inc v); simpl in *.
- intuition.
- Qed.
-
- (* This version has better computational properties *)
- Lemma df_eval_complete (σ:df_env) (f:DefinedFunction) :
- incl (df_free_variables f) (domain σ) -> {v | df_eval σ f = Some v}.
- Proof.
- case_eq (df_eval σ f); simpl.
- - intros r ?? ; exists r; eauto.
- - intros ? inc.
- destruct (df_eval_complete' _ _ inc); congruence.
- Defined.
-
- Lemma df_eval_none (σ:df_env) (f:DefinedFunction) :
- df_eval σ f = None ->
- {v | In v (df_free_variables f) /\ ~ In v (domain σ)}.
- Proof.
- intros.
- destruct (incl_dec (df_free_variables f) (domain σ)).
- - destruct (df_eval_complete _ _ i); congruence.
- - apply (nincl_exists) in n; trivial.
- Qed.
-
- (* Either we can evaluate df or we are missing a variable definition.
- Note that this theorem may fail to hold if we change the definition of
- division to make it partial.
- *)
- Lemma df_eval_compute (σ:df_env) (f:DefinedFunction) :
- {v | df_eval σ f = Some v} + {x | In x (df_free_variables f) /\ ~ In x (domain σ)}.
- Proof.
- case_eq (df_eval σ f); simpl.
- - eauto.
- - intros H; apply df_eval_none in H; eauto.
- Defined.
-
- Lemma df_eval_closed (f:DefinedFunction) :
- df_closed f -> {v | df_eval nil f = Some v}.
- Proof.
- intros c.
- apply (df_eval_complete nil f).
- rewrite df_closed_nil by trivial.
- simpl; reflexivity.
- Defined.
-
- Lemma df_eval_lookup_on (σ₁ σ₂:df_env) (f:DefinedFunction) :
- lookup_equiv_on (df_free_variables f) σ₁ σ₂ ->
- df_eval σ₁ f = df_eval σ₂ f.
- Proof.
- intros lookeq.
- induction f; simpl in *; trivial
- ; try solve [apply lookup_equiv_on_dom_app in lookeq; intuition
- ; rewrite H1, H2; trivial
- | rewrite IHf; trivial].
- - apply lookeq; simpl; tauto.
- Qed.
-*)
- End fv.
-
- Section apply.
-
- Fixpoint df_apply {T} (e: DefinedFunction UnitAnn T)
- (args: forall (v:var_type), DefinedFunction UnitAnn (snd v)) : DefinedFunction UnitAnn T :=
- match e with
- | Number _ x => Number tt x
- | Constant t _ x => Constant tt x
- | DVector n _ df => DVector tt (fun x => df_apply (df x) args)
- | DMatrix n m _ df => DMatrix tt (fun i j => df_apply (df i j) args)
- | Var v _ => args v
- | Plus _ l r => Plus tt (df_apply l args) (df_apply r args)
- | Times _ l r => Times tt (df_apply l args) (df_apply r args)
- | Minus _ l r => Minus tt (df_apply l args) (df_apply r args)
- | Divide _ l r => Divide tt (df_apply l args) (df_apply r args)
- | Square _ e => Square tt (df_apply e args)
- | Exp _ e => Exp tt (df_apply e args)
- | Log _ e => Log tt (df_apply e args)
- | Abs _ e => Abs tt (df_apply e args)
- | Sign _ e => Sign tt (df_apply e args)
- | PSign _ e => PSign tt (df_apply e args)
- | Max _ l r => Max tt (df_apply l args) (df_apply r args)
- | VectorElem n _ l i => VectorElem tt (df_apply l args) i
- | MatrixElem m n _ l i j => MatrixElem tt (df_apply l args) i j
- | VectorDot n _ l r => VectorDot tt (df_apply l args) (df_apply r args)
- | VectorSum n _ l => VectorSum tt (df_apply l args)
- | MatrixSum n m _ l => MatrixSum tt (df_apply l args)
- | VectorScalMult n _ x r => VectorScalMult tt (df_apply x args) (df_apply r args)
- | MatrixScalMult n m _ x r => MatrixScalMult tt (df_apply x args) (df_apply r args)
- | MatrixVectorMult n m _ l r => MatrixVectorMult tt (df_apply l args) (df_apply r args)
- | MatrixVectorAdd n m _ l r => MatrixVectorAdd tt (df_apply l args) (df_apply r args)
- | MatrixMult n m p _ l r => MatrixMult tt (df_apply l args) (df_apply r args)
- | VectorPlus n _ l r => VectorPlus tt (df_apply l args) (df_apply r args)
- | VectorMinus n _ l r => VectorMinus tt (df_apply l args) (df_apply r args)
- | MatrixPlus n m _ l r => MatrixPlus tt (df_apply l args) (df_apply r args)
- | MatrixMinus n m _ l r => MatrixMinus tt (df_apply l args) (df_apply r args)
- | VectorApply n _ x s l => VectorApply tt x (df_apply s args) (df_apply l args)
- | MatrixApply n m _ x s l => MatrixApply tt x (df_apply s args) (df_apply l args)
- | VLossfun n _ v1 v2 s l r => VLossfun tt v1 v2 (df_apply s args) (df_apply l args) r
- | MLossfun n m _ v1 v2 s l r => MLossfun tt v1 v2 (df_apply s args) (df_apply l args) r
- end.
-
- End apply.
-
-End DefinedFunctions.
-
-Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "DVector"%string
- | Case_aux c "DMatrix"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string
- | Case_aux c "VectorDot"%string
- | Case_aux c "VectorSum"%string
- | Case_aux c "MatrixSum"%string
- | Case_aux c "VectorElem"%string
- | Case_aux c "MatrixElem"%string
- | Case_aux c "MatrixVectorMult"%string
- | Case_aux c "MatrixVectorAdd"%string
- | Case_aux c "MatrixMult"%string
- | Case_aux c "VectorPlus"%string
- | Case_aux c "VectorMinus"%string
- | Case_aux c "MatrixPlus"%string
- | Case_aux c "MatrixMinus"%string
- | Case_aux c "VectorScalMult"%string
- | Case_aux c "MatrixScalMult"%string
- | Case_aux c "VectorApply"%string
- | Case_aux c "MatrixApply"%string
- | Case_aux c "VLossfun"%string
- | Case_aux c "MLossfun"%string].
-
-Ltac refl_simpler :=
- repeat
- match goal with
- | [H: @eq var_type _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H)
- | [H: @equiv var_type _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H)
- | [H: @eq definition_function_types _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H)
- | [H: @equiv definition_function_types _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H)
- end.
-
-Section real_pfs.
-
- Local Existing Instance floatish_R.
- Import Reals.
- Import List.
-
- Lemma MaxDerivedMax_eq (a b : DefinedFunction UnitAnn DTfloat) :
- forall σ, df_eval σ (Max tt a b) = df_eval σ (MaxDerived a b).
- Proof.
- simpl; intros σ.
- destruct (df_eval σ a); destruct (df_eval σ b); trivial.
- f_equal.
- autorewrite with Rarith in *.
- destruct (Rle_dec d d0).
- - rewrite Rmax_right by trivial.
- rewrite Rabs_pos_eq by lra.
- lra.
- - rewrite Rmax_left by lra.
- rewrite Rabs_minus_sym.
- rewrite Rabs_pos_eq by lra.
- lra.
- Qed.
-
-(* Lemma coerce_dec_id {A} (dec:forall x y:A, {x=y}+{x<>y}) (x:A) (pf:x=x) : coerce pf x = x.
- Proof.
- unfold coerce.
- replace pf with (eq_refl A); trivial.
- apply UIP_dec.
- apply dec.
- generalize (@UIP_dec A dec pf).
- Lemma var_type_UIP_refl {x:var_type} (e:x=x) : e = eq_refl x.
- Proof.
- apply (UIP_dec vart_dec x x pf).
- Qed.
-
- unfold coerce.
- destruct pf.
- destruct pf.
- exact a.
- Defined.
-*)
-
-Tactic Notation "DefinedFunction_scalar_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string].
-
-
- Lemma backpropeq_gen (x : SubVar) (env gradenv : df_env) (dfexpr : DefinedFunction UnitAnn DTfloat) (grad : float) :
- let xvar := (x, DTfloat) in
- is_scalar_function dfexpr ->
- vartlookup gradenv (x,DTfloat) <> None ->
- match df_eval_deriv env dfexpr xvar,
- backprop_lookup (Some gradenv) xvar,
- backprop_lookup (df_eval_backprop_deriv env dfexpr gradenv grad) xvar
- with
- | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1
- | None, _, None => True
- | _, _, _ => False
- end.
- Proof.
- simpl.
- intros is_scalar.
- generalize is_scalar.
- revert grad gradenv.
- pattern dfexpr.
- revert dfexpr is_scalar.
- DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl.
- - Case "Number"%string.
- intros _ _ grad gradenv _ xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Constant"%string.
- intros _ _ grad gradenv _ xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Var"%string.
- intros sv _ grad gradenv _ xinn.
- case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto].
- destruct (var_dec x sv); simpl.
- + subst.
- rewrite H; simpl.
- rewrite lookup_update.
- destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence].
- unfold addvar; simpl.
- rewrite H.
- lra.
- + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ].
- case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros.
- * rewrite lookup_update_neq by congruence.
- rewrite H.
- lra.
- * rewrite H.
- lra.
- - Case "Plus"%string.
- intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn.
- case_eq (df_eval_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv isc1 xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr grad ge' isc2).
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_backprop_deriv env r ge' grad) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr; trivial.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv isc1 xinn).
- case_eq (df_eval_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Minus"%string.
- intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn.
- case_eq (df_eval_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv isc1 xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- grad)%R ge' isc2).
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr; trivial.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv isc1 xinn).
- case_eq (df_eval_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Times"%string.
- intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn.
- case_eq (df_eval env l);
- [ intros le eqle | intros eqle]; simpl; trivial.
- case_eq (df_eval_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval env r);
- [ intros re eqre | intros eqre]
- ; simpl; trivial.
- case_eq (df_eval_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (re * grad)%R gradenv isc1 xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env l gradenv (re * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (le * grad)%R ge' isc2).
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (le * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_backprop_deriv env l gradenv (re * grad)%R ); simpl; trivial; intros.
- apply IHr; trivial.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + case_eq (df_eval env r);
- [ intros re eqre | intros eqre]
- ; simpl; trivial.
- specialize (IHl (re * grad)%R gradenv isc1 xinn).
- case_eq (df_eval_backprop_deriv env l gradenv (re * grad)%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Divide"%string.
- intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn.
- case_eq (df_eval env l);
- [ intros le eqle | intros eqle]; simpl; trivial.
- case_eq (df_eval_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval env r);
- [ intros re eqre | intros eqre]
- ; simpl; trivial.
- case_eq (df_eval_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (grad / re)%R gradenv isc1 xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env l gradenv (grad / re)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- le / (re * re) * grad)%R ge' isc2).
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_backprop_deriv env r ge' (- le / (re * re) * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_backprop_deriv env l gradenv (grad / re)%R ); simpl; trivial; intros.
- apply IHr; trivial.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + case_eq (df_eval env r);
- [ intros re eqre | intros eqre]
- ; simpl; trivial.
- specialize (IHl (grad / re)%R gradenv isc1 xinn).
- case_eq (df_eval_backprop_deriv env l gradenv (grad / re)%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Square"%string.
- intros _ e IHe grad gradenv isc xinn.
- case_eq (df_eval env e);
- [ intros le eqee | intros eqee]; simpl; trivial.
-
- specialize (IHe (2 * le * grad)%R gradenv isc xinn).
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env e gradenv (2 * le * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Exp"%string.
- intros _ e IHe grad gradenv isc xinn.
- case_eq (df_eval env e);
- [ intros le eqee | intros eqee]; simpl; trivial.
-
- specialize (IHe (grad * exp le)%R gradenv isc xinn).
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env e gradenv (grad * exp le)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Log"%string.
- intros _ e IHe grad gradenv isc xinn.
- case_eq (df_eval env e);
- [ intros le eqee | intros eqee]; simpl; trivial.
-
- specialize (IHe (grad / le)%R gradenv isc xinn).
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env e gradenv (grad / le)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Abs"%string.
- intros _ e IHe grad gradenv isc xinn.
- case_eq (df_eval env e);
- [ intros le eqee | intros eqee]; simpl; trivial.
-
- specialize (IHe (grad * (sign le))%R gradenv isc xinn).
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_backprop_deriv env e gradenv (grad * (sign le))%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Sign"%string.
- intros _ e IHe grad gradenv isc xinn.
- specialize (IHe 0%R gradenv isc xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- lra.
- - Case "PSign"%string.
- intros _ e IHe grad gradenv isc xinn.
- specialize (IHe 0%R gradenv isc xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- lra.
- - Case "Max"%string.
- intros _ l r IHl IHr grad gradenv [isc1 isc2] xinn.
- specialize (IHl grad gradenv isc1 xinn).
- specialize (IHr grad gradenv isc2 xinn).
-
- case_eq (df_eval env l); simpl; trivial
- ; intros eld eqeld.
- case_eq (df_eval env r ); simpl; intros; trivial.
- destruct (Rle_dec eld d); simpl.
- + destruct (df_eval_deriv env r (x, DTfloat)); simpl; trivial.
- + destruct (df_eval_deriv env l (x, DTfloat)); simpl; trivial.
- Qed.
-
- (*
-
- Lemma tree_backpropeq_gen (x : SubVar) (env gradenv : df_env)
- (dfexpr : DefinedFunction EvalAnn DTfloat) (grad : float) :
- let xvar := (x, DTfloat) in
- is_scalar_function dfexpr ->
- vartlookup gradenv (x,DTfloat) <> None ->
- match df_eval_tree_deriv env dfexpr xvar,
- backprop_lookup (Some gradenv) xvar,
- backprop_lookup (df_eval_tree_backprop_deriv env dfexpr gradenv grad) xvar
- with
- | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1
- | None, _, None => True
- | _, _, _ => False
- end.
- Proof.
- simpl.
- intros is_scalar.
- revert grad gradenv.
- pattern dfexpr.
- revert dfexpr is_scalar.
- DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl.
-
- - Case "Number"%string.
- intros _ _ grad gradenv xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Constant"%string.
- intros _ _ grad gradenv xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Var"%string.
- intros sv _ grad gradenv xinn.
- case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto].
- destruct (var_dec x sv); simpl.
- + subst.
- rewrite H; simpl.
- rewrite lookup_update.
- destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence].
- unfold addvar; simpl.
- rewrite H.
- lra.
- + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ].
- case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros.
- * rewrite lookup_update_neq by congruence.
- rewrite H.
- lra.
- * rewrite H.
- lra.
- - Case "Plus"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr grad ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' grad) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Minus"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Times"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (get_annotation r * grad)%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (get_annotation l * grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (get_annotation l * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl (get_annotation r * grad)%R gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Divide"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (grad / get_annotation r)%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl (grad / get_annotation r)%R gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r )%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Square"%string.
- intros _ e IHe grad gradenv xinn.
-
- specialize (IHe (2 * (get_annotation e) * grad)%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (2 * (get_annotation e) * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Exp"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad * exp (get_annotation e))%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * exp (get_annotation e))%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
-
- - Case "Log"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad / get_annotation e)%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad / get_annotation e)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
-
- - Case "Abs"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad * (sign (get_annotation e)))%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * (sign (get_annotation e)))%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Sign"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe 0%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- replace (de * 0)%R with (0)%R in IHe by lra.
- replace (0 * grad)%R with (0)%R by lra.
- apply IHe.
- - Case "PSign"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe 0%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- replace (de * 0)%R with (0)%R in IHe by lra.
- replace (0 * grad)%R with (0)%R by lra.
- apply IHe.
- - Case "Max"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- specialize (IHl grad gradenv xinn).
- specialize (IHr grad gradenv xinn).
- destruct (Rle_dec (get_annotation l) (get_annotation r)); simpl.
- destruct (df_eval_tree_deriv env r (x, DTfloat)); simpl; trivial.
- destruct (df_eval_tree_deriv env l (x, DTfloat)); simpl; trivial.
- Qed.
- *)
-
- Lemma eval_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T) :
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl -> df_eval σ df <> None.
- Proof.
- revert σ.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros;
- try solve
- [congruence
- |
- destruct H; simpl in *;
- specialize (IHdf1 σ); specialize (IHdf2 σ);
- match_option; [|tauto];
- cut_to IHdf2;
- [ match_option; tauto | easy]
- |
- specialize (IHdf σ); simpl in IHdf;
- cut_to IHdf;
- [ match_option; tauto | easy]
- ].
- - Case "DVector"%string.
- apply vectoro_to_ovector_not_none; intro.
- specialize (H i σ); simpl in H; apply H.
- rewrite vforall_forall in H0.
- now specialize (H0 i).
- - Case "DMatrix"%string.
- unfold matrixo_to_omatrix.
- apply vectoro_to_ovector_not_none; intro.
- apply vectoro_to_ovector_not_none; intro.
- specialize (H i i0 σ); simpl in H; apply H.
- rewrite vforall_forall in H0; specialize (H0 i).
- rewrite vforall_forall in H0; now specialize (H0 i0).
- - Case "Var"%string.
- induction σ.
- + simpl in H; tauto.
- + simpl in *.
- match_case; intros.
- destruct H.
- * congruence.
- * now apply IHσ.
- - Case "VectorApply"%string.
- destruct H; simpl in *.
- specialize (IHdf2 σ).
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- apply vectoro_to_ovector_not_none; intro.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil)).
- now apply IHdf1.
- - Case "MatrixApply"%string.
- destruct H; simpl in *.
- specialize (IHdf2 σ).
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- unfold matrixo_to_omatrix.
- apply vectoro_to_ovector_not_none; intro.
- apply vectoro_to_ovector_not_none; intro.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil)).
- now apply IHdf1.
- - Case "VLossfun"%string.
- destruct H; simpl in *.
- specialize (IHdf2 σ).
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- match_option.
- apply vectoro_to_ovector_not_none in eqq0.
- + tauto.
- + intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil)).
- now apply IHdf1.
- - Case "MLossfun"%string.
- destruct H; simpl in *.
- specialize (IHdf2 σ).
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- unfold matrixo_to_omatrix.
- match_option.
- apply vectoro_to_ovector_not_none in eqq0.
- + tauto.
- + intros.
- apply vectoro_to_ovector_not_none; intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil)).
- now apply IHdf1.
- Qed.
-
- Lemma eval_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T) :
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- {d:definition_function_types_interp T | df_eval σ df = Some d}.
- Proof.
- intros.
- case_eq (df_eval σ df); intros.
- - now exists d.
- - generalize (eval_fully_closed_not_none σ df).
- intros; simpl in *.
- cut_to H1; tauto.
- Qed.
-
- Lemma closed_over_cons {T} (df:DefinedFunction UnitAnn T) (v:var_type) (vl : list var_type):
- df_closed_over df vl -> df_closed_over df (v::vl).
- Proof.
- unfold df_closed_over.
- intros.
- apply incl_tl.
- apply H.
- Qed.
-
- Lemma fully_closed_over_cons {T} (df:DefinedFunction UnitAnn T) (v:var_type)
- (vl : list var_type):
- fully_closed_over df vl -> fully_closed_over df (v::vl).
- Proof.
- intros.
- apply (fully_closed_over_map df vl (fun vl1 => cons v vl1)); trivial.
- unfold In_compat_map.
- intros.
- now apply in_cons.
- Qed.
-
- Lemma fully_closed_over_exchange_vars {T} (df:DefinedFunction UnitAnn T) (v1 v:var_type)
- (vl : list var_type):
- fully_closed_over df (v1 :: v :: vl) -> fully_closed_over df (v :: v1 :: vl).
- Proof.
- intros.
- apply (fully_closed_over_map df (v1 :: v :: vl)
- (fun vl1 => match vl1 with
- | a :: b :: vl2 => b :: a :: vl2
- | _ => vl1
- end )); trivial.
- unfold In_compat_map.
- intros.
- destruct vl0; trivial.
- destruct vl0; trivial.
- unfold In.
- unfold In in H0.
- tauto.
- Qed.
-
- Lemma fully_closed_over_singleton {T} (df:DefinedFunction UnitAnn T) (v:var_type)
- (vl : list var_type):
- fully_closed_over df (v::nil) -> fully_closed_over df (v::vl).
- Proof.
- intros.
- induction vl; trivial.
- apply fully_closed_over_exchange_vars.
- now apply fully_closed_over_cons.
- Qed.
-
- Lemma fully_closed_over_exchange_2vars {T} (df:DefinedFunction UnitAnn T)
- (v1 v2 v:var_type) (vl : list var_type):
- fully_closed_over df (v1 :: v2 :: v:: vl) -> fully_closed_over df (v :: v1 :: v2 :: vl).
- Proof.
- intros.
- apply (fully_closed_over_map df (v1 :: v2 :: v :: vl)
- (fun vl1 => match vl1 with
- | a :: b :: c :: vl2 => c :: a :: b :: vl2
- | _ => vl1
- end )); trivial.
- unfold In_compat_map.
- intros.
- destruct vl0; trivial.
- destruct vl0; trivial.
- destruct vl0; trivial.
- unfold In.
- unfold In in H0.
- tauto.
- Qed.
-
- Lemma fully_closed_over_pair {T} (df:DefinedFunction UnitAnn T) (v1 v2:var_type)
- (vl : list var_type):
- fully_closed_over df (v1::v2::nil) -> fully_closed_over df (v1::v2::vl).
- Proof.
- intros.
- induction vl; trivial.
- apply fully_closed_over_exchange_2vars.
- apply fully_closed_over_exchange_2vars.
- now apply fully_closed_over_cons.
- Qed.
-
- Lemma fully_closed_subst {T} (vl:list var_type) (df:DefinedFunction UnitAnn T) (v:var_type)
- (e':DefinedFunction UnitAnn (snd v)):
- fully_closed_over df (v::vl) ->
- fully_closed_over e' vl ->
- fully_closed_over (df_subst df v e') vl.
- Proof.
- revert vl.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros; try solve
- [easy
- |
- apply IHdf; [apply H | apply H0]
- |
- split; destruct H; simpl in *;
- [ apply IHdf1; [apply H | apply H0]
- | apply IHdf2; [apply H1 | apply H0]]
- ].
- - Case "DVector"%string.
- apply vforall_forall; intros; simpl in *.
- specialize (H i); apply H.
- rewrite vforall_forall in H0.
- specialize (H0 i).
- apply H0.
- apply H1.
- - Case "DMatrix"%string.
- apply vforall_forall; intros.
- apply vforall_forall; intros; simpl in *.
- specialize (H i i0); apply H.
- rewrite vforall_forall in H0; specialize (H0 i).
- rewrite vforall_forall in H0; specialize (H0 i0).
- apply H0.
- apply H1.
- - Case "Var"%string.
- unfold substvar.
- destruct H.
- + subst.
- unfold substvar.
- match_destr; [ | congruence].
- refl_simpler.
- simpl; trivial.
- + destruct v; destruct v0.
- simpl in *.
- match_destr.
- red in e; subst.
- simpl; trivial.
- - Case "VectorApply"%string.
- destruct H; split; trivial.
- + apply IHdf2.
- * apply H1.
- * apply H0.
- - Case "MatrixApply"%string.
- destruct H; split; trivial.
- + apply IHdf2.
- * apply H1.
- * apply H0.
- - Case "VLossfun"%string.
- destruct H; split; trivial.
- + apply IHdf2.
- * apply H1.
- * apply H0.
- - Case "MLossfun"%string.
- destruct H; split; trivial.
- + apply IHdf2.
- * apply H1.
- * apply H0.
- Qed.
-
- Lemma fully_closed_deriv {T} (df:DefinedFunction UnitAnn T) (s:SubVar)
- (vl : list var_type):
- fully_closed_over df vl ->
- fully_closed_over (df_deriv df (s, DTfloat)) vl.
- Proof.
- revert s; revert vl.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros; try solve
- [easy
- |
- apply IHdf,H
- |
- split; try easy; now apply IHdf
- |
- destruct H; repeat split; try easy;
- [apply IHdf2, H0 | apply IHdf1, H]
- |
- destruct H; repeat split; try easy;
- [ apply IHdf1, H | apply IHdf2, H0]
- ].
- - Case "DVector"%string.
- apply vforall_forall; intros.
- apply H.
- rewrite vforall_forall in H0.
- now apply H0.
- - Case "DMatrix"%string.
- apply vforall_forall; intros.
- apply vforall_forall; intros.
- apply H.
- rewrite vforall_forall in H0.
- specialize (H0 i).
- rewrite vforall_forall in H0.
- now specialize (H0 i0).
- - Case "Max"%string.
- destruct H; repeat split; try easy.
- apply IHdf2, H0.
- apply IHdf1, H.
- apply IHdf2, H0.
- apply IHdf1, H.
- - Case "VectorApply"%string.
- apply vforall_forall; intros; simpl in *.
- split; destruct H.
- apply IHdf2, H0.
- apply fully_closed_subst.
- + apply IHdf1.
- now apply fully_closed_over_singleton.
- + simpl; apply H0.
- - Case "MatrixApply"%string.
- apply vforall_forall; intros; simpl in *.
- apply vforall_forall; intros; simpl in *.
- split; destruct H.
- apply IHdf2, H0.
- apply fully_closed_subst.
- + apply IHdf1.
- now apply fully_closed_over_singleton.
- + simpl; apply H0.
- - Case "VLossfun"%string.
- intros; simpl in *.
- split; destruct H.
- apply IHdf2, H0.
- apply vforall_forall; intros; simpl in *.
- apply fully_closed_subst.
- apply fully_closed_subst.
- + apply IHdf1.
- now apply fully_closed_over_pair.
- + simpl; apply fully_closed_over_cons; apply H0.
- + now simpl.
- - Case "MLossfun"%string.
- intros; simpl in *.
- apply vforall_forall; intros; simpl in *.
- apply vforall_forall; intros; simpl in *.
- destruct H; split; trivial.
- split.
- apply IHdf2, H0.
- apply fully_closed_subst.
- apply fully_closed_subst.
- + apply IHdf1.
- now apply fully_closed_over_pair.
- + simpl; apply fully_closed_over_cons; apply H0.
- + now simpl.
- Qed.
-
- Lemma list_env_iter_total_fun {A} (f : A -> df_env -> option df_env) (env : df_env) (l : list A) :
- (forall (a:A) (env0: df_env), (f a env0) <> None) ->
- list_env_iter f (Some env) l <> None.
- Proof.
- intros.
- generalize env.
- induction l; [simpl; congruence|].
- simpl; intros.
- specialize (H a env0).
- case_eq (f a env0).
- - intros; apply (IHl d).
- - tauto.
- Qed.
-
- Lemma backprop_deriv_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T)
- (grad_env:df_env) (grad: definition_function_types_interp T):
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl -> df_eval_backprop_deriv σ df grad_env grad <> None.
- Proof.
- revert grad_env.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros;
- try solve [congruence
- |
- specialize (IHdf1 grad grad_env); simpl in IHdf1; destruct H;
- match_option; [|tauto];
- specialize (IHdf2 grad d); simpl in IHdf2;
- now apply IHdf2
- ].
- - Case "DVector"%string.
- unfold two_vector_env_iter_alt.
- rewrite vforall_forall in H0.
- apply (list_env_iter_total_fun
- (fun i env => df_eval_backprop_deriv σ (x i) env (grad i))
- grad_env (bounded_seq0 n)).
- intros.
- apply (H a (grad a) env0).
- apply (H0 a).
- - Case "DMatrix"%string.
- unfold two_matrix_env_iter_alt.
- rewrite vforall_forall in H0.
- apply (list_env_iter_total_fun
- (fun i env =>
- list_env_iter
- (fun j env0 =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (Some env) (bounded_seq0 m))
- grad_env (bounded_seq0 n)).
- intros.
- apply (list_env_iter_total_fun
- (fun j env => df_eval_backprop_deriv σ (x a j) env (grad a j))
- env0 (bounded_seq0 m)).
- intros.
- apply (H a a0 (grad a a0) env1).
- specialize (H0 a).
- rewrite vforall_forall in H0.
- apply (H0 a0).
- - Case "Var"%string.
- match_destr.
- - Case "Minus"%string.
- specialize (IHdf1 grad grad_env); simpl in IHdf1; destruct H;
- match_option; [|tauto];
- specialize (IHdf2 (-grad)%R d); simpl in IHdf2;
- now apply IHdf2.
- - Case "Times"%string.
- destruct H.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2.
- + match_option; [|tauto].
- specialize (IHdf1 (d0 * grad)%R grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (d * grad)%R d1); simpl in IHdf2.
- now apply IHdf2.
- - Case "Divide"%string.
- destruct H.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2.
- + match_option; [|tauto].
- specialize (IHdf1 (grad / d0)%R grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (-d / (d0 * d0) * grad)%R d1); simpl in IHdf2.
- now apply IHdf2.
- - Case "Square"%string.
- generalize (eval_fully_closed_not_none σ df); intros; simpl in H0.
- match_option; [|tauto].
- specialize (IHdf (2 * d * grad)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "Exp"%string.
- generalize (eval_fully_closed_not_none σ df); intros; simpl in H0.
- match_option; [|tauto].
- specialize (IHdf (grad * exp d)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "Log"%string.
- generalize (eval_fully_closed_not_none σ df); intros; simpl in H0.
- match_option; [|tauto].
- specialize (IHdf (grad / d)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "Abs"%string.
- generalize (eval_fully_closed_not_none σ df); intros; simpl in H0.
- match_option; [|tauto].
- specialize (IHdf (grad * sign d)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "Sign"%string.
- specialize (IHdf (0)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "PSign"%string.
- specialize (IHdf (0)%R grad_env); simpl in IHdf.
- now apply IHdf.
- - Case "Max"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- match_destr.
- specialize (IHdf2 grad grad_env).
- now apply IHdf2.
- specialize (IHdf1 grad grad_env).
- now apply IHdf1.
- - Case "VectorDot"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- specialize (IHdf1 (vmap (fun rv : R => (rv * grad)%R) d0) grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (vmap (fun lv : R => (lv * grad)%R) d) d1).
- now apply IHdf2.
- - Case "VectorSum"%string.
- specialize (IHdf (ConstVector n grad) grad_env).
- now apply IHdf.
- - Case "MatrixSum"%string.
- specialize (IHdf (ConstMatrix m n grad) grad_env).
- now apply IHdf.
- - Case "VectorElem"%string.
- specialize (IHdf (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i) then grad else 0%R) grad_env).
- now apply IHdf.
- - Case "MatrixElem"%string.
- specialize (IHdf (fun (k1 : {n' : nat | n' < m})
- (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i) then if
- equiv_dec (` k2) (` j) then grad else 0%R else 0%R)
- grad_env).
- now apply IHdf.
- - Case "MatrixVectorMult"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- specialize (IHdf1 (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) => (grad i * d0 j)%R)
- grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (matrix_vector_mult
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d j i)
- grad)
- d1).
- now apply IHdf2.
- - Case "MatrixVectorAdd"%string.
- specialize (IHdf1 grad grad_env); simpl in IHdf1.
- match_option; [|tauto].
- match_option.
- generalize (list_env_iter_total_fun
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (@transpose R m n grad i))
- d (bounded_seq0 n)); intros.
- cut_to H0; [congruence|].
- intros; destruct H.
- now apply IHdf2.
- - Case "MatrixMult"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- specialize (IHdf1 (matrix_mult grad (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < p}) => d0 j i))
- grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d j i) grad)
- d1).
- now apply IHdf2.
- - Case "VectorMinus"%string.
- specialize (IHdf1 grad grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad i)%R) d).
- now apply IHdf2.
- - Case "MatrixMinus"%string.
- specialize (IHdf1 grad grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (- grad i j)%R)
- d).
- now apply IHdf2.
- - Case "VectorScalMult"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- specialize (IHdf1 (vsum (fun j : {n' : nat | n' < n} => (d0 j * grad j)%R))
- grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (fun j : {n' : nat | n' < n} => (d * grad j)%R) d1).
- now apply IHdf2.
- - Case "MatrixScalMult"%string.
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H0.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- match_option; [|tauto].
- specialize (IHdf1 (msum
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (d0 i j * grad i j)%R))
- grad_env); simpl in IHdf1.
- match_option; [|tauto].
- specialize (IHdf2 (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad i j * d)%R)
- d1).
- now apply IHdf2.
- - Case "VectorApply"%string.
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0.
- match_option; [|tauto].
- match_option.
- + specialize (IHdf2 v0 grad_env).
- now apply IHdf2.
- + apply vectoro_to_ovector_exists_None in eqq0.
- destruct eqq0.
- rewrite vmap_nth in e; simpl in e.
- destruct H.
- match_option_in e.
- generalize (fully_closed_deriv df1 v ((v,DTfloat):: nil)); intros.
- cut_to H2; trivial.
- generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (d x) :: nil)
- (df_deriv df1 (v, DTfloat))); intros.
- simpl in H3; cut_to H3; tauto.
- - Case "MatrixApply"%string.
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0.
- match_option; [|tauto].
- match_option.
- + specialize (IHdf2 m0 grad_env).
- now apply IHdf2.
- + unfold matrixo_to_omatrix in eqq0.
- apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0.
- apply vectoro_to_ovector_exists_None in e; destruct e.
- unfold mmap in e.
- rewrite vmap_nth in e; simpl in e.
- rewrite vmap_nth in e; simpl in e.
- destruct H.
- unfold matrix_zip in e.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- generalize (fully_closed_deriv df1 v ((v,DTfloat):: nil)).
- intros; cut_to H2; trivial.
- generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (d x x0) :: nil)
- (df_deriv df1 (v, DTfloat))); intros.
- simpl in H3; cut_to H3; tauto.
- - Case "VLossfun"%string.
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0.
- match_option; [|tauto].
- match_option.
- + specialize (IHdf2 v grad_env).
- now apply IHdf2.
- + apply vectoro_to_ovector_exists_None in eqq0.
- destruct eqq0.
- rewrite vmap_nth in e; simpl in e.
- destruct H.
- match_option_in e.
- generalize (fully_closed_deriv df1 v1 ((v1,DTfloat)::(v2,DTfloat)::nil)).
- intros; cut_to H2; trivial.
- generalize (eval_fully_closed_not_none (mk_env_entry (v1, DTfloat) (d x) ::
- mk_env_entry (v2, DTfloat) (r x) :: nil)
- (df_deriv df1 (v1, DTfloat))); intros.
- simpl in H3; cut_to H3; tauto.
- - Case "MLossfun"%string.
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H0.
- match_option; [|tauto].
- match_option.
- + specialize (IHdf2 m0 grad_env).
- now apply IHdf2.
- + unfold matrixo_to_omatrix in eqq0.
- apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0.
- apply vectoro_to_ovector_exists_None in e; destruct e.
- unfold mmap in e.
- rewrite vmap_nth in e; simpl in e.
- rewrite vmap_nth in e; simpl in e.
- destruct H.
- unfold matrix_zip in e.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- generalize (fully_closed_deriv df1 v1 ((v1,DTfloat)::(v2,DTfloat)::nil)).
- intros; cut_to H2; trivial.
- generalize (eval_fully_closed_not_none (mk_env_entry (v1, DTfloat) (d x x0) ::
- mk_env_entry (v2, DTfloat) (r x x0) :: nil)
- (df_deriv df1 (v1, DTfloat))); intros.
- simpl in H3; cut_to H3; tauto.
- Qed.
-
- Lemma backprop_deriv_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T)
- (grad_env:df_env) (grad: definition_function_types_interp T):
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- {d:df_env | df_eval_backprop_deriv σ df grad_env grad = Some d}.
- Proof.
- case_eq (df_eval_backprop_deriv σ df grad_env grad); intros.
- - now exists d.
- - generalize (backprop_deriv_fully_closed_not_none σ df grad_env grad).
- intros; simpl in *.
- cut_to H1; tauto.
- Qed.
-
- Lemma eval_deriv_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T)
- (v:var_type):
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl -> df_eval_deriv σ df v <> None.
- Proof.
- revert σ v.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; intros; simpl in *;
- try solve [
- congruence
- |
- destruct H
- ; specialize (IHdf1 σ v); specialize (IHdf2 σ v)
- ;cut_to IHdf1; trivial
- ;match_option; [|tauto]
- ;cut_to IHdf2; trivial
- ;match_option; tauto
- |
- destruct H;
- specialize (IHdf1 σ v); specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1;
- cut_to H1; trivial;
- match_option; [|tauto];
- cut_to IHdf1; trivial;
- match_option; [|tauto];
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2;
- match_option; [|tauto];
- cut_to IHdf2; trivial;
- match_option; tauto
- |
- generalize (eval_fully_closed_not_none σ df); intros;
- specialize (IHdf σ v);
- simpl in H0; cut_to H0; trivial;
- match_option; [|tauto];
- cut_to IHdf; trivial;
- match_option; tauto
- |
- specialize (IHdf σ v);
- generalize (eval_fully_closed_not_none σ df); intros;
- simpl in H0; cut_to H0; trivial;
- match_option; tauto
- ].
- - Case "DVector"%string.
- apply vectoro_to_ovector_not_none; intros; apply H.
- rewrite vforall_forall in H0; apply H0.
- - Case "DMatrix"%string.
- apply vectoro_to_ovector_not_none; intros.
- apply vectoro_to_ovector_not_none; intros; apply H.
- rewrite vforall_forall in H0; specialize (H0 i).
- rewrite vforall_forall in H0; apply H0.
- - Case "Max"%string.
- destruct H;
- specialize (IHdf1 σ v); specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1;
- cut_to H1; trivial;
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2;
- match_option; [|tauto].
- case_eq ( Rle_dec d d0 ); intros.
- cut_to IHdf2; trivial.
- cut_to IHdf1; trivial.
- - Case "VectorApply"%string.
- destruct H.
- specialize (IHdf2 σ v0);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- apply vectoro_to_ovector_not_none; intros.
- match_option.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil) (v, DTfloat)).
- cut_to IHdf1; trivial.
- tauto.
- - Case "MatrixApply"%string.
- destruct H.
- specialize (IHdf2 σ v0);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- apply vectoro_to_ovector_not_none; intros.
- apply vectoro_to_ovector_not_none; intros.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil) (v, DTfloat)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- - Case "VLossfun"%string.
- destruct H.
- specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- match_option.
- apply vectoro_to_ovector_not_none in eqq1.
- tauto.
- intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil) (v1, DTfloat)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- - Case "MLossfun"%string.
- destruct H.
- specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- match_option.
- unfold matrixo_to_omatrix in eqq1.
- apply vectoro_to_ovector_not_none in eqq1.
- tauto.
- intros.
- apply vectoro_to_ovector_not_none; intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil)(v1, DTfloat)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- Qed.
-
- Lemma eval_deriv_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T)
- (v:var_type):
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- {d:definition_function_types_interp T | df_eval_deriv σ df v = Some d}.
- Proof.
- case_eq (df_eval_deriv σ df v); intros.
- - now exists d.
- - generalize (eval_deriv_fully_closed_not_none σ df v).
- intros; simpl in *.
- cut_to H1; tauto.
- Qed.
-
- Lemma eval_deriv_genvar_fully_closed_not_none {T} (σ:df_env) (df:DefinedFunction UnitAnn T)
- (v:df_env):
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl -> df_eval_deriv_genvar σ df v <> None.
- Proof.
- revert σ v.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; intros; simpl in *;
- try solve [
- congruence
- |
- destruct H
- ; specialize (IHdf1 σ v); specialize (IHdf2 σ v)
- ;cut_to IHdf1; trivial
- ;match_option; [|tauto]
- ;cut_to IHdf2; trivial
- ;match_option; tauto
- |
- destruct H;
- specialize (IHdf1 σ v); specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1;
- cut_to H1; trivial;
- match_option; [|tauto];
- cut_to IHdf1; trivial;
- match_option; [|tauto];
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2;
- match_option; [|tauto];
- cut_to IHdf2; trivial;
- match_option; tauto
- |
- generalize (eval_fully_closed_not_none σ df); intros;
- specialize (IHdf σ v);
- simpl in H0; cut_to H0; trivial;
- match_option; [|tauto];
- cut_to IHdf; trivial;
- match_option; tauto
- |
- specialize (IHdf σ v);
- generalize (eval_fully_closed_not_none σ df); intros;
- simpl in H0; cut_to H0; trivial;
- match_option; tauto
- ].
- - Case "DVector"%string.
- apply vectoro_to_ovector_not_none; intros; apply H.
- rewrite vforall_forall in H0; apply H0.
- - Case "DMatrix"%string.
- apply vectoro_to_ovector_not_none; intros.
- apply vectoro_to_ovector_not_none; intros; apply H.
- rewrite vforall_forall in H0; specialize (H0 i).
- rewrite vforall_forall in H0; apply H0.
- - Case "Max"%string.
- destruct H;
- specialize (IHdf1 σ v); specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df1); intros; simpl in H1;
- cut_to H1; trivial;
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H2;
- match_option; [|tauto].
- case_eq ( Rle_dec d d0 ); intros.
- cut_to IHdf2; trivial.
- cut_to IHdf1; trivial.
- - Case "VectorApply"%string.
- destruct H.
- specialize (IHdf2 σ v0);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- apply vectoro_to_ovector_not_none; intros.
- match_option.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i) :: nil) (mk_genvar_env v)).
- cut_to IHdf1; trivial.
- tauto.
- - Case "MatrixApply"%string.
- destruct H.
- specialize (IHdf2 σ v0);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- apply vectoro_to_ovector_not_none; intros.
- apply vectoro_to_ovector_not_none; intros.
- specialize (IHdf1 (mk_env_entry (v, DTfloat) (d i i0) :: nil) (mk_genvar_env v)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- - Case "VLossfun"%string.
- destruct H.
- specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- match_option.
- apply vectoro_to_ovector_not_none in eqq1.
- tauto.
- intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i) :: mk_env_entry (v2, DTfloat) (r i) :: nil) (mk_genvar_env v1)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- - Case "MLossfun"%string.
- destruct H.
- specialize (IHdf2 σ v);
- generalize (eval_fully_closed_not_none σ df2); intros; simpl in H1.
- cut_to H1; trivial.
- match_option; [|tauto].
- cut_to IHdf2; trivial.
- match_option; [|tauto].
- match_option.
- unfold matrixo_to_omatrix in eqq1.
- apply vectoro_to_ovector_not_none in eqq1.
- tauto.
- intros.
- apply vectoro_to_ovector_not_none; intros.
- specialize (IHdf1 (mk_env_entry (v1, DTfloat) (d i i0) :: mk_env_entry (v2, DTfloat) (r i i0) :: nil) (mk_genvar_env v1)).
- cut_to IHdf1; trivial.
- match_option; tauto.
- Qed.
-
- Definition scalarMult (T : definition_function_types) (c : float) :=
- match T return
- definition_function_types_interp T -> definition_function_types_interp T with
- | DTfloat => fun f => (c * f)%R
- | DTVector n => fun f => fun i => (c * f i)%R
- | DTMatrix n m => fun f => fun i j => (c * f i j)%R
- end.
-
- Definition dfti_gen_plus {T : definition_function_types} :=
- match T return
- definition_function_types_interp T -> definition_function_types_interp T ->
- definition_function_types_interp T with
- | DTfloat => fun f g => (f + g)%R
- | DTVector n => fun f g => fun i => (f i + g i)%R
- | DTMatrix n m => fun f g => fun i j => (f i j + g i j)%R
- end.
-
- Lemma subvar_addvar_scalar_neq (env : df_env) (oval : float) (s : SubVar) (v: var_type) (grad : definition_function_types_interp (snd v)) :
- let sv := (s, DTfloat) in
- vartlookup env sv = Some oval ->
- v <> sv ->
- subvar sv (vart_update env v (addvar v env grad)) oval = 0%R.
- Proof.
- intros.
- unfold subvar; simpl.
- rewrite lookup_update_neq; trivial.
- rewrite H.
- lra.
- Qed.
-
- Lemma subvar_addvar_scalar_eq (env : df_env) (s : SubVar) (oval grad : float) :
- let v := (s, DTfloat) in
- vartlookup env v = Some oval ->
- subvar v (vart_update env v (addvar v env grad)) oval = grad.
- Proof.
- intros.
- unfold subvar; simpl.
- rewrite lookup_update.
- unfold addvar; simpl.
- rewrite H.
- lra.
- Qed.
-
- Lemma split_subvar (env1 env2: df_env) (oval val1 : float) (s : SubVar) :
- let v := (s, DTfloat) in
- vartlookup env1 v = Some val1 ->
- subvar v env2 oval = (subvar v env2 val1 + subvar v env1 oval)%R.
- Proof.
- intros.
- unfold subvar; simpl.
- rewrite H.
- case_eq (vartlookup env2 v); intros.
- lra.
- lra.
- Qed.
-
- Lemma vsum_nil : vsum vnil = 0%R.
- Proof.
- reflexivity.
- Qed.
-
- Lemma vsum_mult {n} (v : Vector float n) (c : float) :
- (c * vsum v)%R = vsum (fun j => (c * v j)%R).
- Proof.
- unfold vsum, vector_fold_right1, Datatypes.id; simpl.
- induction n; [ | destruct n].
- - repeat rewrite vector_fold_right1_dep_0; lra.
- - repeat rewrite vector_fold_right1_dep_1; lra.
- - repeat rewrite vector_fold_right1_dep_SSn.
- rewrite Rmult_plus_distr_l.
- specialize (IHn (vdrop_last v)); simpl in IHn.
- rewrite IHn.
- f_equal.
- apply vector_fold_right1_dep_ext.
- intros [i pf]; trivial.
- Qed.
-
- Lemma vsum_plus {m:nat} (v1 v2:Vector R m) :
- (vsum v1 + vsum v2)%R = vsum (fun i => (v1 i + v2 i)%R).
- Proof.
- unfold vsum, vector_fold_right1, Datatypes.id; simpl.
- induction m; [ | destruct m].
- - repeat rewrite vector_fold_right1_dep_0; lra.
- - repeat rewrite vector_fold_right1_dep_1; lra.
- - repeat rewrite vector_fold_right1_dep_SSn.
- specialize (IHm (vdrop_last v1) (vdrop_last v2)); simpl in IHm.
- rewrite (Rplus_comm (vlast v2)).
- rewrite (Rplus_assoc (vlast v1)).
- rewrite <- (Rplus_assoc _ _ (vlast v2)).
- rewrite IHm.
- rewrite (Rplus_comm _ (vlast v2)).
- rewrite <- Rplus_assoc.
- f_equal.
- apply vector_fold_right1_dep_ext.
- intros [i pf]; trivial.
- Qed.
-
- Lemma vmap_mult {n} (f: float -> float) (v : Vector float n) (c : float) :
- forall i : {n' : nat | n' < n},
- (c * (vmap f v) i)%R = (vmap (fun x => (c * f x)%R) v) i.
- Proof.
- intros.
- rewrite vmap_nth.
- now rewrite vmap_nth.
- Qed.
-
- Lemma vsum_ext {n} (v v':Vector float n) : vec_eq v v' -> vsum v = vsum v'.
- Proof.
- apply vector_fold_right1_ext.
- Qed.
-
- Lemma msum_ext {m n} (mat mat':Matrix float m n) :
- (forall i j, mat i j = mat' i j) -> msum mat = msum mat'.
- Proof.
- intros.
- apply vsum_ext; intros ?.
- repeat rewrite vmap_nth.
- apply vsum_ext; intros ?; auto.
- Qed.
-
- Lemma msum_mult {m n} (mat : Matrix float m n) (c : float) :
- (c * msum mat)%R = msum (fun i j => (c * mat i j)%R).
- Proof.
- unfold msum.
- rewrite vsum_mult.
- apply vsum_ext; intros i.
- repeat rewrite vmap_nth.
- now rewrite vsum_mult.
- Qed.
-
- Lemma msum_mmap_mult {m n} (mat : Matrix float m n) (c : float) :
- (c * msum mat)%R = msum (mmap (fun x => c * x)%R mat).
- Proof.
- rewrite msum_mult.
- apply msum_ext; intros i j.
- now rewrite mmap_nth.
- Qed.
-
- Lemma msum_mmap_div_denom {m n} (mat : Matrix float m n) (c : float) :
- msum (mmap (fun u : R => (u / c)%R) mat) = (msum mat / c)%R.
- Proof.
- transitivity (msum (mmap (fun u : R => (/ c * u)%R) mat)).
- - apply msum_ext; intros i j.
- repeat rewrite mmap_nth.
- lra.
- - rewrite <- msum_mmap_mult.
- lra.
- Qed.
-
- Lemma vsum0 n : vsum (fun _ : {n' : nat | (n' < n)%nat} => 0%R) = 0%R.
- Proof.
- generalize (vsum_mult (fun _ : {n' : nat | (n' < n)%nat} => 0%R) 0%R); intros HH.
- rewrite Rmult_0_l in HH.
- symmetry.
- simpl in *.
- erewrite vsum_ext; [eassumption | ].
- intro; simpl; lra.
- Qed.
-
- Lemma vsum_unitvector {n} (v:Vector R n) i :
- vsum (fun j => (v j * UnitVector n i j)%R) = v i.
- Proof.
- unfold vsum, vector_fold_right1, Datatypes.id, UnitVector; simpl.
- revert n v i.
- destruct i.
- induction n; [ | destruct n].
- - lia.
- - repeat rewrite vector_fold_right1_dep_1.
- destruct x; [ | lia]; simpl.
- field_simplify.
- now erewrite index_pf_irrel.
- - repeat rewrite vector_fold_right1_dep_SSn.
- unfold vlast, vdrop_last; simpl.
- destruct (equiv_dec (S n) x).
- + ring_simplify.
- simpl.
- destruct e.
- match goal with
- | [|- (_ + ?x)%R = _ ] => replace x with 0%R
- end.
- * ring_simplify.
- now erewrite index_pf_irrel.
- * rewrite <- (vsum0 (S n)) at 1.
- unfold vsum, vector_fold_right1, Fzero, Datatypes.id; simpl.
- apply (@vector_fold_right1_dep_ext (fun _ => R)).
- intros [??].
- destruct (equiv_dec x (S n)).
- -- destruct e.
- lia.
- -- lra.
- + ring_simplify.
- unfold equiv, complement in c.
- assert (pf:x < S n) by lia.
- specialize (IHn (vdrop_last v) pf).
- simpl in IHn.
- erewrite index_pf_irrel.
- rewrite <- IHn.
- apply (@vector_fold_right1_dep_ext (fun _ => R)).
- now intros [??].
- Qed.
-
- Lemma msum_unitmatrix {m n} (v:Matrix R m n) i j :
- msum (fun k l => (v k l * UnitMatrix m n i j k l)%R) = v i j.
- Proof.
- unfold msum.
- unfold UnitMatrix.
- rewrite (vsum_ext _ (
- (fun (k : {n' : nat | n' < m}) => @vsum floatish_R _
- (fun (l : {m' : nat | m' < n}) =>
- (v k l *
- (if equiv_dec (` k) (` i) then if equiv_dec (` l) (` j) then 1%R else 0%R else 0%R))%R))
-
- ))
- by (intros ?; now rewrite vmap_nth).
- rewrite (vsum_ext _ (
- (fun (k : {n' : nat | n' < m}) => (if equiv_dec (` k) (` i) then
- @vsum floatish_R _
- (fun (l : {m' : nat | m' < n}%nat) =>
- ((v k) l *
- if equiv_dec (` l) (` j) then 1%R else 0%R))%R else 0%R))
-
- )).
- - rewrite (vsum_ext _ (
- (fun (k : {n' : nat | n' < m}) => (if equiv_dec (` k) (` i) then
- v k j else 0%R))
-
- )).
- + rewrite (vsum_ext _ (
- (fun (k : {n' : nat | n' < m}) => ((transpose v) j k * @UnitVector floatish_R m i k)%R)
-
- )).
- * now rewrite vsum_unitvector.
- * unfold UnitVector; intros ?; simpl.
- dest_eqdec; unfold transpose; simpl
- ; lra.
- + intros ?.
- dest_eqdec; trivial.
- apply vsum_unitvector.
- - intros ?.
- dest_eqdec; trivial.
- rewrite <- (vsum0 n) at 1.
- apply vsum_ext.
- intros ?; lra.
- Qed.
-
- Ltac vectoro_assert_forall_in H i
- := match type of H with vectoro_to_ovector ?x = Some ?y =>
- assert (forall i, x i = Some (y i)) end.
-
- Lemma vartlookup_list_env_iter {A}
- (s: SubVar)
- (f : A -> df_env -> option df_env)
- (l : list A) (env fenv: df_env):
- list_env_iter f (Some env) l = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- (forall (env fenv: df_env) (i:A),
- f i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- vartlookup fenv (s, DTfloat) <> None.
- Proof.
- intros.
- revert H0 H.
- generalize env.
- induction l.
- - simpl; intros.
- now invcs H.
- - simpl; intros.
- generalize (list_env_iter_none f l); intros.
- assert (f a env0 <> None); [congruence | ].
- case_eq (f a env0); [|congruence].
- intros.
- apply (IHl d).
- + specialize (H1 env0 d a).
- now apply H1.
- + now rewrite H4 in H.
- Qed.
-
-(* Theorem df_eval_deriv_same {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s:SubVar) :
- let v := (s, DTfloat) in
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- df_eval_deriv σ df v = df_eval σ (df_deriv df v).
- Proof.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl in *;trivial; intros
- ; try (destruct H; rewrite IHdf1; trivial; rewrite IHdf2; trivial)
- ; try (rewrite IHdf; trivial; do 2 match_option).
- - Case "DVector"%string.
- f_equal.
- apply functional_extensionality; intros.
- rewrite vforall_forall in H0.
- specialize (H x0); simpl in H.
- specialize (H0 x0).
- apply H; trivial.
- - Case "DMatrix"%string.
- f_equal.
- apply functional_extensionality; intros.
- apply functional_extensionality; intros.
- rewrite vforall_forall in H0.
- specialize (H x0); simpl in H.
- specialize (H0 x0).
- rewrite vforall_forall in H0.
- apply H; trivial.
- - case "Times"%string.
- intros; do 4 match_option.
- - Case "Divide"%string.
- intros; do 4 match_option.
- - Case "Sign"%string.
- match_option.
- generalize (eval_deriv_fully_closed_not_none σ df (s, DTfloat)); tauto.
- - Case "PSign"%string.
- match_option.
- generalize (eval_deriv_fully_closed_not_none σ df (s, DTfloat)); tauto.
- - Case "Max"%string.
- assert (df_eval σ df1 <> None) by (apply eval_fully_closed_not_none; trivial).
- assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- match_option; [|congruence].
- assert (df_eval σ (df_deriv df1 (s, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- assert (df_eval σ (df_deriv df2 (s, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- destruct (df_eval σ (df_deriv df1 (s, DTfloat))); [|congruence].
- destruct (df_eval σ (df_deriv df2 (s, DTfloat))); [|congruence].
- unfold pos_sign; simpl.
- case_eq (Rle_dec d d0); intros; f_equal.
- destruct (Rge_dec (d0 - d) 0); lra.
- destruct (Rge_dec (d0 - d) 0); lra.
- - Case "VectorDot"%string.
- do 4 match_option.
- rewrite vsum_plus.
- do 2 f_equal.
- f_equal.
- apply functional_extensionality; intros.
- lra.
- - Case "MatrixVectorMult"%string.
- do 4 match_option.
- f_equal.
- apply functional_extensionality; intros.
- unfold matrix_vector_mult.
- rewrite vsum_plus; simpl.
- f_equal.
- apply functional_extensionality; intros.
- lra.
- - Case "MatrixMult"%string.
- do 4 match_option.
- f_equal.
- apply functional_extensionality; intros.
- apply functional_extensionality; intros.
- unfold matrix_mult.
- rewrite vsum_plus; simpl; f_equal.
- apply functional_extensionality; intros.
- lra.
- - Case "VectorScalMult"%string.
- do 4 match_option.
- f_equal.
- apply functional_extensionality; intros.
- lra.
- - Case "MatrixScalMult"%string.
- do 4 match_option.
- f_equal.
- apply functional_extensionality; intros.
- apply functional_extensionality; intros.
- lra.
- - Case "VectorApply"%string.
- destruct H.
- assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- rewrite IHdf2; trivial.
- match_option.
- + f_equal.
- apply functional_extensionality; intros.
- assert ( df_eval_deriv [mk_env_entry (v, DTfloat) (d x)] df1 (v, DTfloat) =
- df_eval σ (df_subst (df_deriv df1 (v, DTfloat)) (v, DTfloat)
- (VectorElem () df2 x))).
- XXX
- now rewrite H2.
- + assert ( df_eval σ (df_deriv df2 (s, DTfloat)) <> None ); [|tauto].
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- - Case "MatrixApply"%string.
- destruct H.
- assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- rewrite IHdf2; trivial.
- match_option.
- + f_equal.
- apply functional_extensionality; intros.
- apply functional_extensionality; intros.
- assert ( df_eval_deriv [mk_env_entry (v, DTfloat) (d x x0)] df1 (v, DTfloat) =
- df_eval σ (df_subst (df_deriv df1 (v, DTfloat)) (v, DTfloat)
- (MatrixElem () df2 x x0))).
- XXX
- now rewrite H2.
- + assert ( df_eval σ (df_deriv df2 (s, DTfloat)) <> None ); [|tauto].
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- - Case "VLossfun"%string.
- destruct H.
- assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- rewrite IHdf2; trivial.
- match_option.
- do 2 match_option.
- + do 2 f_equal.
- apply functional_extensionality; intros.
- XXX
- + generalize (vectoro_to_ovector_exists_None eqq2).
- intros; destruct H2.
- assert (df_eval
- σ
- (df_subst (df_subst (df_deriv df1 (v1, DTfloat)) (v1, DTfloat) (VectorElem () df2 x))
- (v2, DTfloat) (Number () (r x))) <> None); [|tauto].
- apply eval_fully_closed_not_none.
- apply fully_closed_subst; simpl; [|trivial].
- apply fully_closed_subst; simpl.
- * apply fully_closed_deriv.
- now apply fully_closed_over_pair.
- * now apply fully_closed_over_cons.
- + generalize (vectoro_to_ovector_exists_None eqq1).
- intros; destruct H2.
- match_option_in e.
- assert (df_eval_deriv [mk_env_entry (v1, DTfloat) (d x); mk_env_entry (v2, DTfloat) (r x)]
- df1 (v1, DTfloat) <> None); [|tauto].
- apply eval_deriv_fully_closed_not_none; trivial.
- - Case "MLossfun"%string.
- destruct H.
- assert (df_eval σ df2 <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- rewrite IHdf2; trivial.
- match_option.
- do 2 match_option.
- + do 2 f_equal.
- XXX
- + generalize (vectoro_to_ovector_exists_None eqq2); intros; destruct H2.
- generalize (vectoro_to_ovector_exists_None e); intros; destruct H2.
- match_option_in e0.
- match_option_in eqq3.
- assert (df_eval σ
- (df_subst
- (df_subst (df_deriv df1 (v1, DTfloat)) (v1, DTfloat) (MatrixElem () df2 x x0))
- (v2, DTfloat) (Number () (r x x0))) <> None); [|tauto].
- apply eval_fully_closed_not_none.
- apply fully_closed_subst; simpl; [|trivial].
- apply fully_closed_subst; simpl.
- * apply fully_closed_deriv.
- now apply fully_closed_over_pair.
- * now apply fully_closed_over_cons.
- + generalize (vectoro_to_ovector_exists_None eqq1).
- intros; destruct H2.
- generalize (vectoro_to_ovector_exists_None e).
- intros; destruct H2.
- match_option_in e0.
- assert (df_eval_deriv [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)]
- df1 (v1, DTfloat) <> None); [|tauto].
- apply eval_deriv_fully_closed_not_none; trivial.
-
-
- *)
-
- Theorem df_eval_deriv_scalar_same (σ:df_env) (df:DefinedFunction UnitAnn DTfloat) (s:SubVar) :
- let v := (s, DTfloat) in
- let vl := map (fun ve => projT1 ve) σ in
- is_scalar_function df ->
- fully_closed_over df vl ->
- df_eval_deriv σ df v = df_eval σ (df_deriv df v).
- Proof.
- simpl.
- intros is_scalar.
- generalize is_scalar.
- pattern df.
- revert df is_scalar.
- DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case
- ; simpl; trivial; intros
- ; try (destruct H1; destruct is_scalar;
- specialize (H H3 H1); specialize (H0 H4 H2); now rewrite H, H0)
- ; try (rewrite H; trivial; do 2 match_option).
- - Case "Times"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- rewrite H, H0.
- do 4 match_option.
- - Case "Divide"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- rewrite H, H0.
- do 4 match_option.
- - Case "Sign"%string.
- match_option.
- assert ( df_eval_deriv σ e (s, DTfloat) <> None); [|tauto].
- apply eval_deriv_fully_closed_not_none; trivial.
- - Case "PSign"%string.
- match_option.
- assert ( df_eval_deriv σ e (s, DTfloat) <> None); [|tauto].
- apply eval_deriv_fully_closed_not_none; trivial.
- - Case "Max"%string.
- destruct is_scalar; destruct H1.
- cut_to H; trivial.
- cut_to H0; trivial.
- rewrite H, H0.
- assert (df_eval σ l <> None) by (apply eval_fully_closed_not_none; trivial).
- assert (df_eval σ r <> None) by (apply eval_fully_closed_not_none; trivial).
- match_option; [|congruence].
- match_option; [|congruence].
- assert (df_eval σ (df_deriv l (s, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- assert (df_eval σ (df_deriv r (s, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- destruct (df_eval σ (df_deriv l (s, DTfloat))); [|congruence].
- destruct (df_eval σ (df_deriv r (s, DTfloat))); [|congruence].
- unfold pos_sign; simpl.
- case_eq (Rle_dec d d0); intros; f_equal.
- destruct (Rge_dec (d0 - d) 0); lra.
- destruct (Rge_dec (d0 - d) 0); lra.
- Qed.
-
- Lemma vartlookup_list_env_iter2 {A}
- (s: SubVar)
- {f : A -> df_env -> option df_env}
- {l : list A} {env fenv: df_env}:
- list_env_iter f (Some env) l = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- (forall (env fenv: df_env) (i:A),
- f i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- vartlookup fenv (s, DTfloat) <> None.
- Proof.
- apply (vartlookup_list_env_iter s f l env fenv).
- Qed.
-
- Lemma scalarMult_list_env_iter
- (s: SubVar) (c val1 val2:float) (A :Type)
- (f g : A -> df_env -> option df_env)
- (l : list A) (env1 env2 fenv1 fenv2: df_env):
- list_env_iter f (Some env1) l = Some fenv1 ->
- list_env_iter g (Some env2) l = Some fenv2 ->
- vartlookup env1 (s, DTfloat) = Some val1 ->
- vartlookup env2 (s, DTfloat) = Some val2 ->
- (forall (i:A) (env1 env2 fenv1 fenv2: df_env) (v1 v2: float),
- vartlookup env1 (s, DTfloat) = Some v1 ->
- vartlookup env2 (s, DTfloat) = Some v2 ->
- f i env1 = Some fenv1 -> g i env2 = Some fenv2 ->
- subvar (s, DTfloat) fenv1 v1 = (c * subvar (s, DTfloat) fenv2 v2)%R) ->
- (forall (env fenv: df_env) (i:A),
- f i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- (forall (env fenv: df_env) (i:A),
- g i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- subvar (s, DTfloat) fenv1 val1 = (c * subvar (s, DTfloat) fenv2 val2)%R.
- Proof.
- intros.
- generalize (vartlookup_list_env_iter s f l env1 fenv1); intros.
- assert (vartlookup env1 (s, DTfloat) <> None).
- rewrite H1; discriminate.
- assert (vartlookup env2 (s, DTfloat) <> None).
- rewrite H2; discriminate.
- specialize (H6 H H7 H4).
- generalize (vartlookup_list_env_iter s g l env2 fenv2); intros.
- specialize (H9 H0 H8 H5).
- revert H1 H2 H H0.
- generalize env1 env2 val1 val2.
- induction l.
- - intros.
- unfold subvar; simpl.
- unfold list_env_iter in H; simpl in H.
- unfold list_env_iter in H0; simpl in H0.
- invcs H; invcs H0.
- rewrite H1; rewrite H2; lra.
- - simpl; intros.
- generalize (list_env_iter_none f l); intros.
- assert (f a env0 <> None); [congruence | ].
- case_eq (f a env0); [intros|congruence].
- generalize (list_env_iter_none g l); intros.
- assert (g a env3 <> None); [congruence | ].
- case_eq (g a env3); [intros|congruence].
- assert (vartlookup d (s, DTfloat) <> None).
- apply (H4 env0 d a); trivial; congruence.
- assert (vartlookup d0 (s, DTfloat) <> None).
- apply (H5 env3 d0 a); trivial; congruence.
- case_eq (vartlookup d (s, DTfloat)); [intros | tauto].
- case_eq (vartlookup d0 (s, DTfloat)); [intros | tauto].
- specialize (IHl d d0 d1 d2).
- specialize (H3 a env0 env3 d d0 val0 val3).
- rewrite (split_subvar d fenv1 val0 d1); trivial.
- rewrite (split_subvar d0 fenv2 val3 d2); trivial.
- specialize (H3 H1 H2 H12 H15).
- rewrite H12 in H.
- rewrite H15 in H0.
- specialize (IHl H18 H19 H H0).
- lra.
- Qed.
-
- Lemma list_env_iter_subvar_env2
- (s: SubVar) (val1 val2:float) (A :Type)
- (f g : A -> df_env -> option df_env)
- (l : list A) (env1 env2 fenv1 fenv2: df_env):
- list_env_iter f (Some env1) l = Some fenv1 ->
- list_env_iter g (Some env2) l = Some fenv2 ->
- vartlookup env1 (s, DTfloat) = Some val1 ->
- vartlookup env2 (s, DTfloat) = Some val2 ->
- (forall (i:A) (env1 env2 fenv1 fenv2: df_env) (v1 v2: float),
- vartlookup env1 (s, DTfloat) = Some v1 ->
- vartlookup env2 (s, DTfloat) = Some v2 ->
- f i env1 = Some fenv1 -> g i env2 = Some fenv2 ->
- subvar (s, DTfloat) fenv1 v1 = subvar (s, DTfloat) fenv2 v2) ->
- (forall (env fenv: df_env) (i:A),
- f i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- (forall (env fenv: df_env) (i:A),
- g i env = Some fenv ->
- vartlookup env (s, DTfloat) <> None ->
- vartlookup fenv (s, DTfloat) <> None) ->
- subvar (s, DTfloat) fenv1 val1 = subvar (s, DTfloat) fenv2 val2.
- Proof.
- intros.
- generalize (scalarMult_list_env_iter s 1%R val1 val2 A f g l env1 env2 fenv1 fenv2).
- intros.
- specialize (H6 H H0 H1 H2).
- cut_to H6.
- now replace (1 * subvar (s, DTfloat) fenv2 val2)%R with (subvar (s, DTfloat) fenv2 val2) in H6 by lra.
- intros.
- replace (1 * subvar (s, DTfloat) fenv3 v2)%R with (subvar (s, DTfloat) fenv3 v2) by lra.
- apply (H3 i env0 env3 fenv0 fenv3); trivial.
- apply H4.
- apply H5.
- Qed.
-
- Lemma scalarMult_backprop_grad_scalar {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (s: SubVar) (grad_env1 grad_env2:df_env) (grad : definition_function_types_interp T) (c:float) :
- let v := (s, DTfloat) in
- vartlookup grad_env1 v <> None -> vartlookup grad_env2 v <> None ->
- df_eval_backprop_deriv σ df grad_env1 (scalarMult T c grad) <> None ->
- df_eval_backprop_deriv σ df grad_env2 grad <> None ->
- df_eval_backprop_delta σ df v grad_env1 (scalarMult T c grad) =
- lift (fun e => scalarMult (snd v) c e) (df_eval_backprop_delta σ df v grad_env2 grad).
- Proof.
- revert grad_env1 grad_env2.
- unfold df_eval_backprop_delta.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case
- ; simpl; intros grad_env1 grad_env2 neq1 neq2; intros.
- - Case "Number"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [|tauto].
- intros; simpl; f_equal.
- unfold subvar; simpl.
- match_destr; match_destr.
- inversion H1; subst.
- inversion H2; subst.
- lra.
- - Case "Constant"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- intros; simpl; f_equal.
- unfold subvar; simpl.
- match_destr; match_destr.
- inversion H1; subst.
- inversion H2; subst.
- lra.
- - Case "DVector"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- intros; simpl.
- unfold lift.
- match_option; [|tauto].
- case_eq (two_vector_env_iter_alt
- (fun (x0 : DefinedFunction Ann DTfloat) (g : R) (genv : df_env) =>
- df_eval_backprop_deriv σ x0 genv g) grad_env2 x grad); [|tauto].
- unfold two_vector_env_iter_alt in *.
- intros; f_equal.
- apply (scalarMult_list_env_iter
- s c d0 d {n' : nat | n' < n}
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (c * grad i)%R)
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad i))
- (bounded_seq0 n) grad_env1 grad_env2); trivial.
- + intros.
- specialize (H i (grad i) env1 env2).
- assert (vartlookup env1 (s, DTfloat) <> None); [congruence|].
- assert (vartlookup env2 (s, DTfloat) <> None); [congruence|].
- specialize (H H9 H10).
- assert (df_eval_backprop_deriv σ (x i) env1 (c * grad i)%R <> None); [congruence|].
- assert (df_eval_backprop_deriv σ (x i) env2 (grad i) <> None); [congruence|].
- specialize (H H11 H12).
- unfold lift in H; simpl in H.
- rewrite H5, H6, H7, H8 in H.
- now inversion H.
- + intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H5).
- + intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H5).
- - Case "DMatrix"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- intros; simpl.
- unfold lift.
- match_option; [|tauto].
- case_eq (two_matrix_env_iter_alt
- (fun (x0 : DefinedFunction Ann DTfloat) (g : R) (genv : df_env) =>
- df_eval_backprop_deriv σ x0 genv g) grad_env2 x grad); [|tauto].
- intros; f_equal.
- unfold two_matrix_env_iter_alt in *.
- apply (scalarMult_list_env_iter
- s c d0 d {n' : nat | n' < n}
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R)
- (Some env) (bounded_seq0 m))
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j)) (Some env)
- (bounded_seq0 m))
- (bounded_seq0 n) grad_env1 grad_env2); trivial.
- + intros.
- apply (scalarMult_list_env_iter
- s c v1 v2 {m' : nat | m' < m}
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R)
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (bounded_seq0 m) env1 env2); trivial.
- * intros.
- specialize (H i i0 (grad i i0) env0 env3).
- assert (vartlookup env0 (s, DTfloat) <> None); [congruence|].
- assert (vartlookup env3 (s, DTfloat) <> None); [congruence|].
- specialize (H H13 H14).
- assert (df_eval_backprop_deriv σ (x i i0) env0 (c * grad i i0)%R <> None); [congruence|].
- assert (df_eval_backprop_deriv σ (x i i0) env3 (grad i i0) <> None); [congruence|].
- specialize (H H15 H16).
- unfold lift in H; simpl in H.
- rewrite H9, H10, H11, H12 in H.
- now inversion H.
- * intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H9).
- * intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H9).
- + intros.
- apply (vartlookup_list_env_iter
- s
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (c * grad i j)%R)
- (bounded_seq0 m) env fenv); trivial; intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H7).
- + intros.
- apply (vartlookup_list_env_iter
- s
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (bounded_seq0 m) env fenv); trivial; intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H7).
- - Case "Var"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- intros.
- destruct (vart_dec v (s, DTfloat)).
- + subst; simpl.
- rewrite H2; simpl.
- rewrite subvar_addvar_scalar_eq; trivial.
- rewrite H1; simpl.
- now rewrite subvar_addvar_scalar_eq.
- + case_eq (vartlookup grad_env1 v); intros; simpl.
- * case_eq (vartlookup grad_env2 v); intros; simpl; f_equal.
- -- rewrite subvar_addvar_scalar_neq; trivial.
- rewrite subvar_addvar_scalar_neq; trivial.
- lra.
- -- rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl.
- rewrite H1.
- lra.
- * case_eq (vartlookup grad_env2 v); intros; simpl; f_equal.
- -- rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl.
- rewrite H2.
- lra.
- -- unfold subvar; simpl.
- rewrite H2; rewrite H1.
- lra.
- - Case "Plus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad)%R); intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 grad d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (c * grad)%R); [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "Minus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad)%R); intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 (-grad)%R d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (-(c * grad))%R); [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 (-grad)%R); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- replace (c * -grad)%R with (-(c*grad))%R in IHdf2 by lra.
- rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "Times"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (d * grad)%R grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (d * (c * grad))%R); intros.
- rewrite H1, H2, H5 in H; simpl in H.
- rewrite H1, H2 in H0; simpl in H0.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 (d * grad)%R); intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (d0 * grad)%R d3 d4).
- replace (c * (d * grad))%R with (d * (c * grad))%R in IHdf1 by lra.
- rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d3 (d0 * (c * grad))%R); [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (d0 * grad)%R); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (c * (d0 * grad))%R with (d0 * (c*grad))%R in IHdf2 by lra.
- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) = Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "Divide"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (grad / d)%R grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (c * grad / d)%R); intros.
- rewrite H1 in H; rewrite H2 in H; simpl in H.
- rewrite H1 in H0; rewrite H2 in H0; simpl in H0.
- rewrite H5 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 (grad / d)%R); intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (- d0 / (d*d) * grad)%R d3 d4).
- replace (c * (grad / d))%R with (c * grad / d )%R in IHdf1 by lra.
- rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d3 (- d0 /(d * d) * (c * grad))%R); [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (- d0 / (d * d) * grad)%R); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (c * (-d0 / (d * d) * grad))%R with (- d0/(d * d) * (c*grad))%R in IHdf2 by lra.
- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) = Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "Square"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df); [ | tauto]; intros.
- specialize (IHdf (2 * d1 * grad)%R grad_env1 grad_env2); simpl in *.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- replace (2 * d1 * (c * grad))%R with (c * (2 * d1 * grad))%R by lra.
- rewrite H3 in H; rewrite H3 in H0; simpl in *.
- replace (2 * d1 * (c * grad))%R with (c * (2 * d1 * grad))%R in H by lra.
- now apply IHdf.
- - Case "Exp"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df); [ | tauto]; intros.
- specialize (IHdf (grad * exp d1)%R grad_env1 grad_env2); simpl in *.
- replace (c * grad * exp d1)%R with (c * (grad * exp d1))%R by lra.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- rewrite H3 in H; rewrite H3 in H0.
- replace (c * grad * exp d1)%R with (c * (grad * exp d1))%R in H by lra.
- now apply IHdf.
- - Case "Log"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df); [ | tauto]; intros.
- specialize (IHdf (grad / d1)%R grad_env1 grad_env2 ); simpl in *.
- replace (c * grad / d1)%R with (c * (grad / d1))%R by lra.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- rewrite H3 in H; rewrite H3 in H0.
- replace (c * grad / d1)%R with (c * (grad / d1))%R in H by lra.
- now apply IHdf.
- - Case "Abs"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df); [ | tauto]; intros.
- specialize (IHdf (grad * sign d1)%R grad_env1 grad_env2); simpl in *.
- replace (c * grad * (@sign floatish_R d1))%R
- with (c * (grad * (@sign floatish_R d1)))%R by lra.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- rewrite H3 in H; rewrite H3 in H0.
- replace (c * grad * (@sign floatish_R d1))%R
- with (c * (grad * (@sign floatish_R d1)))%R in H by lra.
- now apply IHdf.
- - Case "Sign"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (0)%R grad_env1 grad_env2); simpl in *.
- replace (0%R) with (c * 0)%R at 1 by lra.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- replace (0%R) with (c * 0)%R in H by lra.
- now apply IHdf.
- - Case "PSign"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (0)%R grad_env1 grad_env2); simpl in *.
- replace (0%R) with (c * 0)%R at 1 by lra.
- rewrite H1 in IHdf; rewrite H2 in IHdf.
- replace (0%R) with (c * 0)%R in H by lra.
- now apply IHdf.
- - Case "Max"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df1); [ | tauto]; intros.
- case_eq (df_eval σ df2); [ | tauto]; intros.
- rewrite H3 in H; rewrite H3 in H0.
- rewrite H4 in H; rewrite H4 in H0.
- case_eq (Rle_dec d1 d2); intros.
- + specialize (IHdf2 grad grad_env1 grad_env2); simpl in *.
- rewrite H1 in IHdf2; rewrite H2 in IHdf2.
- rewrite H5 in H; rewrite H5 in H0.
- now apply IHdf2.
- + specialize (IHdf1 grad grad_env1 grad_env2); simpl in *.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1.
- rewrite H5 in H; rewrite H5 in H0.
- now apply IHdf1.
- - Case "VectorDot"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- case_eq (df_eval σ df1); [ | tauto]; intros.
- case_eq (df_eval σ df2); [ | tauto]; intros.
- specialize (IHdf1 (vmap (fun rv => (rv * grad)%R) d2) grad_env1 grad_env2).
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1
- (vmap (fun rv : R => (rv * (c * grad))%R) d2)); intros.
- rewrite H3 in H; rewrite H4 in H; rewrite H5 in H; simpl in H.
- rewrite H3 in H0; rewrite H4 in H0; simpl in H0.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2
- (vmap (fun rv : R => (rv * grad)%R) d2)); intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (vmap (fun lv => (lv *grad)%R) d1) d3 d4).
- replace (fun i => (c * vmap (fun rv : R => rv * grad) d2 i)%R) with
- (vmap (fun rv : R => (rv * (c * grad))%R) d2) in IHdf1.
- rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d3
- (vmap (fun lv : R => (lv * (c * grad))%R) d1))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (vmap (fun lv : R => (lv * grad)%R) d1))
- ; [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d d5) by trivial.
- rewrite (split_subvar d4 d8 d0 d6) by trivial.
- replace
- (fun i : {n' : nat | n' < n} => (c * vmap (fun lv : R => lv * grad) d1 i)%R) with
- (vmap (fun lv : R => (lv * (c * grad))%R) d1) in IHdf2.
- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d) = Some (c * subvar (s, DTfloat) d4 d0)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) = Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_mult.
- assert ((fun lv => (lv * (c * grad))%R) = (fun x0 => (c * (x0 * grad))%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- now rewrite H13.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_mult.
- assert ((fun rv => (rv * (c * grad))%R) = (fun x0 => (c * (x0 * grad))%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- now rewrite H7.
- + now rewrite H6 in H0.
- + rewrite H3 in H; rewrite H4 in H.
- now rewrite H5 in H.
- - Case "VectorSum"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (ConstVector n grad) grad_env1 grad_env2).
- rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in IHdf.
- replace (ConstVector n (c * grad)%R) with
- (fun i => (c * ConstVector n grad i)%R).
- now apply IHdf.
- now unfold ConstVector.
- - Case "MatrixSum"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (ConstMatrix m n grad) grad_env1 grad_env2).
- rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in IHdf.
- replace (ConstMatrix m n (c * grad)%R) with
- (fun i j => (c * ConstMatrix m n grad i j)%R).
- now apply IHdf.
- now unfold ConstMatrix.
- - Case "VectorElem"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (fun k =>
- if equiv_dec (` k) (` i) then grad else 0%R) grad_env1 grad_env2).
- rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in *.
- replace (fun i0 => (c * (if equiv_dec (` i0) (` i) then grad else 0))%R) with
- (fun k : {n' : nat | n' < n} =>
- if equiv_dec (` k) (` i) then (c * grad)%R else 0%R) in IHdf.
- now rewrite IHdf.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)); lra.
- - Case "MatrixElem"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- specialize (IHdf (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i)
- then if equiv_dec (` k2) (` j) then grad else 0%R else 0%R)
- grad_env1 grad_env2).
- rewrite H1 in IHdf; rewrite H2 in IHdf; simpl in *.
- replace (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i)
- then if equiv_dec (` k2) (` j) then (c * grad)%R else 0%R else 0%R) with
- (fun (i0 : {n' : nat | n' < m}) (j0 : {m' : nat | m' < n}) =>
- (c *
- (if equiv_dec (` i0) (` i)
- then if equiv_dec (` j0) (` j) then grad else 0
- else 0))%R) in *.
- + now rewrite IHdf.
- + apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)); [|lra].
- destruct (equiv_dec (` x0) (` j)); lra.
- - Case "MatrixVectorMult"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (fun i j => (grad i * d j)%R) grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1
- (fun i j => (c * grad i * d j)%R)); intros.
- rewrite H1 in H; rewrite H2 in H; simpl in H.
- rewrite H1 in H0; rewrite H2 in H0; simpl in H0.
- rewrite H5 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2
- (fun i j => (grad i * d j)%R)); intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (matrix_vector_mult (fun i j => d0 j i) grad) d3 d4).
- replace
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (c * (grad i * d j))%R) with
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (c * grad i * d j)%R) in IHdf1.
- * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- unfold lift; match_case; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4
- (matrix_vector_mult (fun i j => d0 j i) grad))
- ; [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (fun i : {n' : nat | n' < n} =>
- (c *
- (@matrix_vector_mult floatish_R _ _
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d0 j i0) grad) i)%R) with
- (matrix_vector_mult
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => d0 j i)
- (fun i : {n' : nat | n' < m} => (c * grad i)%R)) in IHdf2.
- -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) =
- Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) =
- Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- -- unfold matrix_vector_mult.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vsum_mult; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "MatrixVectorAdd"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R))
- ; intros.
- + rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- * rewrite H4 in H0; simpl in H0.
- match_option_in H0; [|tauto].
- match_option_in H; [|tauto].
- rewrite H3 in IHdf1.
- unfold lift.
- f_equal.
- rewrite H4 in IHdf1.
- unfold lift in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite (split_subvar d1 d4 d0 d5) by trivial.
- rewrite (split_subvar d2 d3 d d6) by trivial.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R)
- by (apply IHdf1; trivial; discriminate).
- inversion H9; rewrite H11.
- assert (Some (subvar (s, DTfloat) d4 d5) = Some (c * subvar (s, DTfloat) d3 d6)%R).
- -- f_equal.
- apply (scalarMult_list_env_iter
- s c d5 d6 {m' : nat | m' < n}
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ df2 env
- (transpose
- (fun (i0 : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) =>
- (c * grad i0 j)%R) i))
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i))
- (bounded_seq0 n) d1 d2 d4 d3); trivial.
- ++ intros.
- assert (vartlookup env1 (s, DTfloat) <> None); [congruence|].
- assert (vartlookup env2 (s, DTfloat) <> None); [congruence|].
- assert (df_eval_backprop_deriv
- σ df2 env1
- (transpose
- (fun (i0 : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (c * grad i0 j)%R) i) <> None)
- ; [congruence|].
- assert (df_eval_backprop_deriv σ df2 env2 (transpose grad i) <> None)
- ;[congruence|].
- specialize (IHdf2 (transpose grad i) env1 env2).
- specialize (IHdf2 H15 H16).
- specialize (IHdf2 H17 H18).
- unfold lift in IHdf2; simpl in IHdf2.
- rewrite H10, H12, H14 in IHdf2.
- unfold transpose in IHdf2; unfold transpose in H13.
- rewrite H13 in IHdf2; now inversion IHdf2.
- ++ intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H10).
- ++ intros.
- now apply (df_eval_backprop_deriv_preserves_lookup_not_none H10).
- -- inversion H10; rewrite H13; lra.
- * rewrite H4 in H0; tauto.
- + rewrite H3 in H; tauto.
- - Case "MatrixMult"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (matrix_mult grad (fun i j => d j i)) grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1
- (matrix_mult (fun i j => (c * grad i j)%R)
- (fun i j => d j i))); intros.
- rewrite H1 in H; rewrite H2 in H; simpl in H.
- rewrite H1 in H0; rewrite H2 in H0; simpl in H0.
- rewrite H5 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2
- (matrix_mult grad (fun i j => d j i))); intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (matrix_mult (fun i j => d0 j i) grad) d3 d4).
- replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < p}) =>
- (c *
- (@matrix_mult floatish_R m n p grad
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) =>
- d j0 i0)) i j)%R) with
- (matrix_mult (fun i j => (c * grad i j)%R)
- (fun i j => d j i)) in IHdf1.
- * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- unfold lift; match_case; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (matrix_mult (fun i j => d0 j i) grad))
- ; [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < n}) =>
- (c *
- (@matrix_mult floatish_R p m n
- (fun (i0 : {n' : nat | (n' < p)%nat})
- (j0 : {m' : nat | (m' < m)%nat}) =>
- d0 j0 i0) grad) i j)%R) with
- (matrix_mult (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < m}) => d0 j i)
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (c * grad i j)%R)) in IHdf2.
- -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) =
- Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) =
- Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- -- unfold matrix_mult.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vsum_mult; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- * unfold matrix_mult.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vsum_mult; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "VectorPlus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i => (c * grad i)%R)); intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 grad d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (fun i => (c * grad i)%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "VectorMinus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i => (c * grad i )%R))
- ; intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 (fun i => (- grad i)%R) d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (fun i => (- (c * grad i))%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 (fun i => (- grad i)%R)); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- replace (fun i => (c * - grad i)%R) with (fun i => (-( c * grad i))%R) in IHdf2.
- * rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "MatrixPlus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R))
- ; intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 grad d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (fun i j => (c * grad i j)%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 grad); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "MatrixMinus"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- specialize (IHdf1 grad grad_env1 grad_env2); intros.
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1 (fun i j => (c * grad i j)%R))
- ; intros.
- rewrite H3 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); intros.
- rewrite H4 in H0; simpl in H0.
- + specialize (IHdf2 (fun i j => (- grad i j)%R) d1 d2).
- rewrite H3 in IHdf1; rewrite H4 in IHdf1.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3 (s, DTfloat) neq1).
- case_eq (vartlookup d1 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- case_eq (vartlookup d2 (s, DTfloat)); [ |tauto]; intros.
- rewrite H6 in IHdf2; rewrite H8 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d1 (fun i j => (- (c * grad i j))%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d2 (fun i j => (- grad i j)%R)); [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d1 d5 d0 d3) by trivial.
- rewrite (split_subvar d2 d6 d d4) by trivial.
- replace (fun i j => (c * - grad i j)%R) with
- (fun i j => (-( c * grad i j))%R) in IHdf2.
- * rewrite H9 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d1 d0) = Some (c * subvar (s, DTfloat) d2 d)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d5 d3) = Some (c * subvar (s, DTfloat) d6 d4)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H11; inversion H12.
- rewrite H14; rewrite H15; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + now rewrite H4 in H0.
- + now rewrite H3 in H.
- - Case "VectorScalMult"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (vsum (fun j => (d j * grad j)%R)) grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1
- (vsum (fun j => (d j * (c * grad j))%R)))
-
- ; intros.
- rewrite H1 in H; rewrite H2 in H; simpl in H.
- rewrite H1 in H0; rewrite H2 in H0; simpl in H0.
- rewrite H5 in H; simpl in H.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2
- (vsum (fun j => (d j * grad j)%R)))
- ; intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (fun j => (grad j * d0)%R) d3 d4).
- replace
- (c *
- (@vsum floatish_R _
- (fun (j : {n' : nat | (n' < n)%nat}) =>
- d j * grad j)))%R with
- (vsum
- (fun (j : {n' : nat | n' < n}) =>
- (d j * (c * grad j))%R)) in IHdf1.
- * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d3
- (fun (j : {n' : nat | n' < n}) =>
- (d0 * (c * grad j))%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (fun j => (d0 * grad j)%R))
- ; [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (fun i => (c * (grad i * d0))%R) with
- (fun j => (d0 * (c * grad j))%R) in IHdf2.
- replace (fun j => (grad j * d0)%R) with (fun j => (d0 * grad j)%R) in IHdf2.
- -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) =
- Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) =
- Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- * rewrite vsum_mult; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "MatrixScalMult"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto].
- case_eq ( df_eval σ df1); [|tauto].
- case_eq ( df_eval σ df2); [|tauto]; intros.
- specialize (IHdf1 (msum (fun i j => (d i j * grad i j)%R)) grad_env1 grad_env2).
- rewrite H4 in IHdf1; rewrite H3 in IHdf1; simpl in *.
- case_eq (df_eval_backprop_deriv σ df1 grad_env1
- (msum (fun i j => (d i j * (c * grad i j))%R)))
-
- ; intros.
- rewrite H1, H2, H5 in H; simpl in H.
- rewrite H1, H2 in H0; simpl in H0.
- case_eq (df_eval_backprop_deriv σ df1 grad_env2
- (msum (fun i j => (d i j * grad i j)%R)))
- ; intros.
- rewrite H6 in H0; simpl in H0.
- + specialize (IHdf2 (fun i j => (grad i j * d0)%R) d3 d4).
- replace
- (c *
- (@msum floatish_R _ _
- (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d i j * grad i j)))%R with
-
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d i j * (c * grad i j))%R)) in IHdf1.
- * rewrite H5 in IHdf1; rewrite H6 in IHdf1; simpl in IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 (s, DTfloat) neq1).
- case_eq (vartlookup d3 (s, DTfloat)); [ |tauto]; intros.
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H6 (s, DTfloat) neq2).
- case_eq (vartlookup d4 (s, DTfloat)); [ |tauto]; intros.
- rewrite H8 in IHdf2; rewrite H10 in IHdf2; simpl in *.
- case_eq (df_eval_backprop_deriv σ df2 d3
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (c * grad i j * d0)%R))
- ; [|tauto]; intros.
- case_eq (df_eval_backprop_deriv σ df2 d4 (fun i j => (grad i j * d0)%R))
- ; [|tauto]; intros; simpl; f_equal.
- rewrite (split_subvar d3 d7 d2 d5) by trivial.
- rewrite (split_subvar d4 d8 d1 d6) by trivial.
- replace (fun i j => (c * (grad i j * d0))%R) with
- (fun i j => (c * grad i j * d0)%R) in IHdf2.
- -- rewrite H11 in IHdf2; rewrite H12 in IHdf2; simpl in *.
- assert (Some (subvar (s, DTfloat) d3 d2) =
- Some (c * subvar (s, DTfloat) d4 d1)%R) by
- (apply IHdf1; trivial; discriminate).
- assert (Some (subvar (s, DTfloat) d7 d5) =
- Some (c * subvar (s, DTfloat) d8 d6)%R) by
- (apply IHdf2; trivial; discriminate).
- inversion H13; inversion H14.
- rewrite H16; rewrite H17; lra.
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- * unfold msum.
- rewrite vsum_mult; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite vsum_mult.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + now rewrite H6 in H0.
- + rewrite H2 in H; rewrite H1 in H.
- now rewrite H5 in H.
- - Case "VectorApply"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- simpl in *.
- case_eq (df_eval σ df2); [ | tauto].
- intros.
- rewrite H3 in H; rewrite H3 in H0; simpl in *.
- match_option_in H0; [|tauto].
- specialize (IHdf1 v0 grad_env1 grad_env2).
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1.
- match_option_in H; [|tauto].
- vectoro_assert_forall_in eqq i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in eqq0 i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- assert (v1 = (fun i => (c * v0 i)%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H4 x).
- specialize (H5 x).
- rewrite vmap_nth in H4; simpl in H4.
- rewrite vmap_nth in H5; simpl in H5.
- match_option_in H4.
- match_option_in H5.
- inversion H4; inversion H5; subst.
- assert (Some d2 = Some d3).
- rewrite <- eqq1.
- rewrite <- eqq2; trivial.
- inversion H6; subst; lra.
- subst.
- apply IHdf1; trivial; discriminate.
- - Case "MatrixApply"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- simpl in *.
- case_eq (df_eval σ df2); [ | tauto].
- intros.
- rewrite H3 in H; rewrite H3 in H0; simpl in *.
- match_option_in H0; [|tauto].
- match_option_in H; [|tauto].
- unfold matrixo_to_omatrix in eqq.
- unfold matrixo_to_omatrix in eqq0.
- vectoro_assert_forall_in eqq i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in eqq0 i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- assert (m1 = (fun i j => (c * m0 i j)%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H4 x); specialize (H5 x); simpl in H4; simpl in H5.
- vectoro_assert_forall_in H4 j.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in H5 j.
- apply vectoro_to_ovector_forall_some_f; trivial.
- specialize (H6 x0); specialize (H7 x0).
- unfold mmap in H6; unfold mmap in H7.
- rewrite vmap_nth in H6; rewrite vmap_nth in H6; simpl in H6.
- rewrite vmap_nth in H7; rewrite vmap_nth in H7; simpl in H7.
- match_case_in H6; intros.
- rewrite H8 in H6; simpl in H6.
- match_case_in H7; intros.
- rewrite H9 in H7; simpl in H7.
- match_option_in H6.
- match_option_in H7.
- inversion H6; inversion H7.
- unfold matrix_zip in H8.
- unfold matrix_zip in H9.
- rewrite vmap_nth in H8.
- rewrite vmap_nth in H9.
- unfold vector_zip in H8.
- unfold vector_zip in H9.
- inversion H8; subst r r0.
- inversion H9; subst r1 r2.
- assert (Some d2 = Some d3).
- rewrite <- eqq1; rewrite <- eqq2; trivial.
- inversion H10; subst; lra.
- specialize (IHdf1 m0 grad_env1 grad_env2).
- rewrite H1, H2 in IHdf1; simpl in IHdf1.
- subst m1.
- apply IHdf1; trivial; discriminate.
- - Case "VLossfun"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- simpl in *.
- case_eq (df_eval σ df2); [ | tauto].
- intros.
- rewrite H3 in H; rewrite H3 in H0; simpl in *.
- match_option_in H0; [|tauto].
- match_option_in H; [|tauto].
- vectoro_assert_forall_in eqq i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in eqq0 i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- assert (v0 = (fun i => (c * v i)%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H4 x); specialize (H5 x).
- rewrite vmap_nth in H4; simpl in H4.
- rewrite vmap_nth in H5; simpl in H5.
- match_option_in H4.
- match_option_in H5.
- assert (Some d2 = Some d3).
- rewrite <- eqq1; rewrite <- eqq2; trivial.
- inversion H6; subst.
- inversion H4; inversion H5; lra.
- subst.
- specialize (IHdf1 v grad_env1 grad_env2).
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1.
- apply IHdf1; trivial; discriminate.
- - Case "MLossfun"%string.
- case_eq (vartlookup grad_env1 (s, DTfloat)); [ |tauto]; intros.
- case_eq (vartlookup grad_env2 (s, DTfloat)); [ |tauto]; intros.
- simpl in *.
- case_eq (df_eval σ df2); [ | tauto].
- intros.
- rewrite H3 in H; rewrite H3 in H0; simpl in *.
- match_option_in H0; [|tauto].
- match_option_in H; [|tauto].
- unfold matrixo_to_omatrix in eqq.
- unfold matrixo_to_omatrix in eqq0.
- vectoro_assert_forall_in eqq i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in eqq0 i.
- apply vectoro_to_ovector_forall_some_f; trivial.
- assert (m1 = (fun i j => (c * m0 i j)%R)).
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H4 x); specialize (H5 x); simpl in H4; simpl in H5.
- vectoro_assert_forall_in H4 j.
- apply vectoro_to_ovector_forall_some_f; trivial.
- vectoro_assert_forall_in H5 j.
- apply vectoro_to_ovector_forall_some_f; trivial.
- specialize (H6 x0); specialize (H7 x0).
- unfold mmap in H6; unfold mmap in H7.
- rewrite vmap_nth in H6; rewrite vmap_nth in H6; simpl in H6.
- rewrite vmap_nth in H7; rewrite vmap_nth in H7; simpl in H7.
- match_destr_in H6.
- match_option_in H6.
- match_option_in H7.
- assert (Some d2 = Some d3).
- rewrite <- eqq1; rewrite <- eqq2; trivial.
- inversion H8; subst.
- inversion H6; inversion H7.
- lra.
- rewrite H6.
- specialize (IHdf1 m0 grad_env1 grad_env2).
- rewrite H1 in IHdf1; rewrite H2 in IHdf1; simpl in IHdf1.
- subst.
- apply IHdf1; trivial; discriminate.
- Qed.
-
- Ltac simpl_closed_backprop :=
- match goal with
- | [|- context [
- match df_eval_backprop_deriv ?σ ?df1 ?grad_env1 ?grad with
- | Some _ => _
- | None => _
- end]] => case_eq (df_eval_backprop_deriv σ df1 grad_env1 grad)
- ; [let env := fresh "env" in let eqq := fresh "eqq" in intros env eqq |
- let eqq := fresh "eqq" in
- intros eqq;
- eelim backprop_deriv_fully_closed_not_none; [clear eqq | eapply eqq]; trivial
- ]
- end.
-
- Ltac simpler2 :=
- trivial;
- repeat
- match goal with
- | [ |- Some _ <> None ] => congruence
- | [ |- None <> Some _ ] => congruence
-
- | [H:vartlookup ?grad_env ?a <> None
- |- context [vartlookup ?grad_env ?a]] =>
- case_eq (vartlookup grad_env a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d
- | intros ?; eelim H; solve[auto]]
- | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2
- |- context [vartlookup ?grad_env2 ?a]] =>
- case_eq (vartlookup grad_env2 a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d
- | let eqq := fresh "eqq" in
- intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto]
- ]
-
- | [H:vartlookup ?grad_env ?a <> None
- |- context [vartlookup ?grad_env ?a]] =>
- case_eq (vartlookup grad_env a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d
- | intros ?; eelim H; solve[auto || congruence]]
- | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2
- |- context [vartlookup ?grad_env2 ?a]] =>
- case_eq (vartlookup grad_env2 a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d
- | let eqq := fresh "eqq" in
- intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto || congruence]
- ]
- | [H:vartlookup ?grad_env ?a <> None,
- H2:context [match vartlookup ?grad_env ?a with | _ => _ end] |- _] =>
- case_eq (vartlookup grad_env a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d; rewrite eqq in H2
- | intros ?; eelim H; solve[auto || congruence]]
- | [H: df_eval_backprop_deriv ?σ ?df1 ?grad_env1 _ = Some ?grad_env2,
- H2: context [match vartlookup ?grad_env2 ?a with _ => _ end] |- _] =>
- case_eq (vartlookup grad_env2 a); [
- let d := fresh "val" in
- let eqq := fresh "eqq" in
- intros d eqq; simpl in d; rewrite eqq in H2
- | let eqq := fresh "eqq" in
- intros eqq; eelim df_eval_backprop_deriv_preserves_lookup_not_none; [apply H | idtac | apply eqq]; solve[auto || congruence]]
-
- end.
-
- Lemma backprop_indep_env {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s:SubVar)
- (grad_env1 grad_env2:df_env) (grad : definition_function_types_interp T) :
- let v := (s, DTfloat) in
- vartlookup grad_env1 v <> None ->
- vartlookup grad_env2 v <> None ->
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- df_eval_backprop_delta σ df v grad_env1 grad =
- df_eval_backprop_delta σ df v grad_env2 grad.
- Proof.
- revert grad_env1 grad_env2.
- unfold df_eval_backprop_delta.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case
- ; simpl; intros grad_env1 grad_env2 neq1 neq2; intros.
- - Case "Number"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold snd in *.
- f_equal.
- unfold subvar; simpl.
- rewrite eqq, eqq0.
- lra.
- - Case "Constant"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold snd in *.
- f_equal.
- unfold subvar; simpl.
- rewrite eqq, eqq0.
- lra.
- - Case "DVector"%string.
- rewrite vforall_forall in H0.
- unfold two_vector_env_iter_alt.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift; simpl.
- match_option.
- + match_option.
- f_equal.
- apply (list_env_iter_subvar_env2
- s d d0 {n' : nat | n' < n}
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad i))
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad i))
- (bounded_seq0 n) grad_env1 grad_env2 d1 d2); trivial.
- * intros.
- specialize (H i (grad i) env1 env2).
- cut_to H; try congruence; eauto 3.
- rewrite H1, H2, H3, H4 in H.
- unfold lift in H.
- now inversion H.
- * intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H1); trivial.
- * intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H1); trivial.
- * assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad i)) (Some grad_env2)
- (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- + assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad i)) (Some grad_env1)
- (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- - Case "DMatrix"%string.
- rewrite vforall_forall in H0.
- unfold two_matrix_env_iter_alt.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift; simpl.
- match_option.
- + match_option.
- * f_equal.
- apply (list_env_iter_subvar_env2
- s d d0 {n' : nat | n' < n}
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (Some env) (bounded_seq0 m))
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (Some env) (bounded_seq0 m))
- (bounded_seq0 n) grad_env1 grad_env2 d1 d2); trivial.
- -- intros.
- apply (list_env_iter_subvar_env2
- s v1 v2 {m' : nat | m' < m}
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (bounded_seq0 m) env1 env2 fenv1 fenv2); trivial.
- ++ intros.
- specialize (H i i0 (grad i i0) env0 env3).
- cut_to H.
- rewrite H5, H6, H7, H8 in H.
- unfold lift in H.
- now inversion H.
- congruence.
- congruence.
- specialize (H0 i).
- rewrite vforall_forall in H0.
- apply (H0 i0).
- ++ intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5); trivial.
- ++ intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5); trivial.
- -- intros.
- apply (vartlookup_list_env_iter
- s
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (bounded_seq0 m) env fenv); trivial.
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3); trivial.
- -- intros.
- apply (vartlookup_list_env_iter
- s
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (bounded_seq0 m) env fenv); trivial.
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H3); trivial.
- * assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (Some env) (bounded_seq0 m)) (Some grad_env2) (bounded_seq0 n)
- <> None).
- apply list_env_iter_total_fun; intros.
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- specialize (H0 a).
- rewrite vforall_forall in H0.
- apply (H0 a0).
- tauto.
- + assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad i j))
- (Some env) (bounded_seq0 m)) (Some grad_env1) (bounded_seq0 n)
- <> None).
- apply list_env_iter_total_fun; intros.
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- specialize (H0 a).
- rewrite vforall_forall in H0.
- apply (H0 a0).
- tauto.
- - Case "Var"%string.
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- unfold lift, subvar; simpl.
- destruct (v == (s,DTfloat)).
- + invcs e.
- unfold addvar; simpl.
- rewrite eqq, H0.
- rewrite lookup_update.
- rewrite lookup_update.
- f_equal; lra.
- + assert (v<> (s, DTfloat)) by congruence.
- case_eq (vartlookup grad_env1 v); intros.
- * rewrite lookup_update_neq; trivial.
- rewrite eqq.
- case_eq (vartlookup grad_env2 v); intros.
- -- rewrite lookup_update_neq; trivial.
- rewrite H0; f_equal; lra.
- -- rewrite H0; f_equal; lra.
- * rewrite eqq.
- case_eq (vartlookup grad_env2 v); intros.
- -- rewrite lookup_update_neq; trivial.
- rewrite H0; f_equal; lra.
- -- rewrite H0; f_equal; lra.
- - Case "Plus"%string.
- destruct H.
- unfold lift.
- repeat simpl_closed_backprop.
- simpler2.
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1,eqq3,eqq4 in IHdf1.
- specialize (IHdf2 grad env env1).
- cut_to IHdf2; simpler2.
- unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq2, eqq0 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar env env0 val val1); trivial.
- rewrite (split_subvar env1 env2 val0 val2); trivial.
- rewrite H2, H3; lra.
- - Case "Minus"%string.
- destruct H.
- unfold lift.
- repeat simpl_closed_backprop.
- simpler2.
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1,eqq3,eqq4 in IHdf1.
- specialize (IHdf2 (-grad)%R env env1).
- cut_to IHdf2; simpler2.
- unfold lift in IHdf1; invcs IHdf1.
- rewrite eqq0,eqq2 in IHdf2; unfold lift in IHdf2; invcs IHdf2.
- rewrite (split_subvar env env0 val val1); trivial.
- rewrite (split_subvar env1 env2 val0 val2); trivial.
- rewrite H2, H3; lra.
- - Case "Times"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- unfold lift.
- repeat simpl_closed_backprop.
- simpler2.
- f_equal.
- specialize (IHdf1 (d1 * grad)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq6,eqq2,eqq4 in IHdf1.
- specialize (IHdf2 (d0 * grad)%R env env1).
- cut_to IHdf2; simpler2.
- unfold lift in IHdf1; invcs IHdf1.
- rewrite eqq5, eqq3 in IHdf2; unfold lift in IHdf2; invcs IHdf2.
- rewrite (split_subvar env env0 d val0); trivial.
- rewrite (split_subvar env1 env2 val val1); trivial.
- rewrite H4, H5; lra.
- - Case "Divide"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- unfold lift.
- repeat simpl_closed_backprop.
- simpler2.
- f_equal.
- specialize (IHdf1 (grad/d1)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq6,eqq2,eqq4 in IHdf1.
- specialize (IHdf2 (-d0/(d1*d1) * grad)%R env env1).
- cut_to IHdf2; simpler2.
- unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq5, eqq3 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar env env0 d val0); trivial.
- rewrite (split_subvar env1 env2 val val1); trivial.
- rewrite H4, H5; lra.
- - Case "Square"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift.
- repeat simpl_closed_backprop.
- f_equal.
- specialize (IHdf (2 * d0 * grad)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1,eqq2,eqq3 in IHdf.
- now unfold lift in IHdf; inversion IHdf.
- - Case "Exp"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift.
- repeat simpl_closed_backprop.
- f_equal.
- specialize (IHdf (grad * exp d0)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1,eqq2, eqq3 in IHdf.
- now unfold lift in IHdf; inversion IHdf.
- - Case "Log"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env1 (grad/d0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- unfold lift.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env2 (grad/d0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf (grad / d0)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf.
- now rewrite eqq2, eqq3 in IHdf; unfold lift in IHdf; inversion IHdf.
- - Case "Abs"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env1 (grad * sign d0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- unfold lift.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env2 (grad * sign d0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf (grad * sign d0)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf.
- now rewrite eqq2, eqq3 in IHdf; unfold lift in IHdf; inversion IHdf.
- - Case "Sign"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env1 (0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- unfold lift.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env2 (0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf (0)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq0 in IHdf.
- now rewrite eqq1, eqq2 in IHdf; unfold lift in IHdf; inversion IHdf.
- - Case "PSign"%string.
- match_option; [|tauto].
- assert (df_eval σ df <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env1 (0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- unfold lift.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df grad_env2 (0)%R <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf (0)%R grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq0 in IHdf.
- now rewrite eqq1, eqq2 in IHdf; unfold lift in IHdf; inversion IHdf.
- - Case "Max"%string.
- destruct H.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- specialize (IHdf2 grad grad_env1 grad_env2 neq1 neq2 H0).
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- rewrite eqq, eqq2 in IHdf1.
- rewrite eqq, eqq2 in IHdf2.
- destruct (Rle_dec d0 d1); trivial.
- - Case "VectorDot"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1 (vmap (fun rv : R => (rv * grad)%R) d1)
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H3 in IHdf1.
- assert (df_eval_backprop_deriv σ df1 grad_env1
- (vmap (fun rv : R => (rv * grad)%R) d1) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (vmap (fun rv : R => (rv * grad)%R) d1) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1).
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2).
- specialize (IHdf2 (vmap (fun lv : R => (lv * grad)%R) d0)
- d3 d4 H6 H7 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv σ df2 d3
- (vmap (fun lv : R => (lv * grad)%R) d0) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (vmap (fun lv : R => (lv * grad)%R) d0) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- unfold lift in IHdf1.
- rewrite eqq2, eqq3 in IHdf1; inversion IHdf1.
- rewrite eqq6, eqq7 in IHdf2; inversion IHdf2.
- rewrite (split_subvar d3 d7 d d5); trivial.
- rewrite (split_subvar d4 d8 d2 d6); trivial.
- now rewrite H11, H12.
- - Case "VectorSum"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (ConstVector n grad) grad_env1 grad_env2 neq1 neq2 H).
- now rewrite eqq, eqq0 in IHdf.
- - Case "MatrixSum"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (ConstMatrix m n grad) grad_env1 grad_env2 neq1 neq2 H).
- now rewrite eqq, eqq0 in IHdf.
- - Case "VectorElem"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i)
- then grad else 0%R)
- grad_env1 grad_env2 neq1 neq2 H).
- now rewrite eqq, eqq0 in IHdf.
- - Case "MatrixElem"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i) then
- if equiv_dec (` k2) (` j) then grad else 0%R else 0%R)
- grad_env1 grad_env2 neq1 neq2 H).
- now rewrite eqq, eqq0 in IHdf.
- - Case "MatrixVectorMult"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1 (fun (i : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (grad i * d1 j)%R)
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H3 in IHdf1.
- assert (df_eval_backprop_deriv σ df1 grad_env1
- (fun (i : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (grad i * d1 j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (fun (i : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (grad i * d1 j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1).
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2).
- specialize (IHdf2
- (matrix_vector_mult
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d0 j i) grad)
- d3 d4 H6 H7 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv
- σ df2 d3
- (matrix_vector_mult (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d0 j i)
- grad) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (matrix_vector_mult (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d0 j i)
- grad) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- unfold lift in IHdf1.
- unfold lift in IHdf2.
- rewrite eqq2, eqq3 in IHdf1; inversion IHdf1.
- rewrite eqq6, eqq7 in IHdf2; inversion IHdf2.
- rewrite (split_subvar d3 d7 d d5); trivial.
- rewrite (split_subvar d4 d8 d2 d6); trivial.
- now rewrite H11, H12.
- - Case "MatrixVectorAdd"%string.
- destruct H.
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1 grad
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H1 in IHdf1.
- assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- case_eq (df_eval_backprop_deriv σ df1 grad_env2 grad); [intros | tauto].
- rewrite eqq0, H4 in IHdf1.
- unfold lift in IHdf1.
- inversion IHdf1.
- assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d1)
- (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i)) (Some d2)
- (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- unfold lift; f_equal.
- assert (vartlookup d1 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1).
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H4 (s, DTfloat) neq2).
- assert (vartlookup d3 (s, DTfloat) <> None).
- apply
- (vartlookup_list_env_iter
- s
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i))
- (bounded_seq0 n) d1); trivial.
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H10 (s, DTfloat) H11).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply
- (vartlookup_list_env_iter
- s
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i))
- (bounded_seq0 n) d2); trivial.
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H11 (s, DTfloat) H12).
- unfold subvar in IHdf1; simpl in IHdf1.
- match_option_in IHdf1; [|tauto].
- match_option_in IHdf1; [|tauto].
- rewrite (split_subvar d1 d3 d d5); trivial.
- rewrite (split_subvar d2 d4 d0 d6); trivial.
- rewrite H6.
- apply Rplus_eq_compat_r .
- apply (list_env_iter_subvar_env2
- s d5 d6 {m' : nat | m' < n}
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i))
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad i))
- (bounded_seq0 n) d1 d2 d3 d4); trivial.
- intros.
- specialize (IHdf2 (transpose grad i) env1 env2).
- rewrite H12, H13 in IHdf2.
- cut_to IHdf2; trivial.
- rewrite H14, H15 in IHdf2.
- unfold lift in IHdf2.
- now inversion IHdf2.
- discriminate.
- discriminate.
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H12 (s, DTfloat) H13).
- intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H12 (s, DTfloat) H13).
- - Case "MatrixMult"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1
- (matrix_mult grad (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < p}) => d1 j i))
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H3 in IHdf1.
- assert (df_eval_backprop_deriv σ df1 grad_env1
- (matrix_mult grad (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < p}) => d1 j i)) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (matrix_mult grad (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < p}) => d1 j i)) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- rewrite eqq2, eqq3 in IHdf1.
- unfold lift in IHdf1; inversion IHdf1.
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1).
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2).
- specialize (IHdf2
- (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d0 j i) grad)
- d3 d4 H6 H8 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv
- σ df2 d3
- (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d0 j i) grad) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv
- σ df2 d4
- (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d0 j i) grad) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- unfold lift in IHdf1.
- unfold lift in IHdf2.
- rewrite eqq6, eqq7 in IHdf2; inversion IHdf2.
- rewrite (split_subvar d3 d7 d d5); trivial.
- rewrite (split_subvar d4 d8 d2 d6); trivial.
- now rewrite H7, H12.
- - Case "VectorPlus"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv σ df2 d0 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf1.
- specialize (IHdf2 grad d0 d2).
- assert (vartlookup d0 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1).
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2).
- specialize (IHdf2 H5 H6 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar d0 d3 d d5); trivial.
- rewrite (split_subvar d2 d4 d1 d6); trivial.
- lra.
- - Case "VectorMinus"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv
- σ df2 d0
- (fun i : {n' : nat | n' < n} => (- grad i)%R) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d2
- (fun i : {n' : nat | n' < n} => (- grad i)%R) <> None)
- by (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf1.
- specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad i)%R) d0 d2).
- assert (vartlookup d0 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1).
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2).
- specialize (IHdf2 H5 H6 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar d0 d3 d d5); trivial.
- rewrite (split_subvar d2 d4 d1 d6); trivial.
- lra.
- - Case "MatrixPlus"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv σ df2 d0 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf1.
- specialize (IHdf2 grad d0 d2).
- assert (vartlookup d0 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1).
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2).
- specialize (IHdf2 H5 H6 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar d0 d3 d d5); trivial.
- rewrite (split_subvar d2 d4 d1 d6); trivial.
- lra.
- - Case "MatrixMinus"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 grad <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv
- σ df2 d0
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- specialize (IHdf1 grad grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, eqq1 in IHdf1.
- specialize (IHdf2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad i j)%R)
- d0 d2).
- assert (vartlookup d0 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) neq1).
- assert (vartlookup d2 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq2).
- specialize (IHdf2 H5 H6 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- rewrite eqq0, eqq2 in IHdf1; unfold lift in IHdf1; inversion IHdf1.
- rewrite eqq3, eqq4 in IHdf2; unfold lift in IHdf2; inversion IHdf2.
- rewrite (split_subvar d0 d3 d d5); trivial.
- rewrite (split_subvar d2 d4 d1 d6); trivial.
- lra.
- - Case "VectorScalMult"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R))
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H3 in IHdf1.
- assert (df_eval_backprop_deriv
- σ df1 grad_env1
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R)) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad j)%R))
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1).
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2).
- specialize (IHdf2 (fun j : {n' : nat | n' < n} => (d0 * grad j)%R)
- d3 d4 H6 H7 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv σ df2 d3
- (fun j : {n' : nat | n' < n} => (d0 * grad j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (fun j : {n' : nat | n' < n} => (d0 * grad j)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- unfold lift in IHdf1.
- rewrite eqq2, eqq3 in IHdf1; inversion IHdf1.
- rewrite eqq6, eqq7 in IHdf2; inversion IHdf2.
- rewrite (split_subvar d3 d7 d d5); trivial.
- rewrite (split_subvar d4 d8 d2 d6); trivial.
- now rewrite H11, H12.
- - Case "MatrixScalMult"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df1 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- specialize (IHdf1
- (msum
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R))
- grad_env1 grad_env2 neq1 neq2 H).
- rewrite eqq, H3 in IHdf1.
- assert (df_eval_backprop_deriv
- σ df1 grad_env1
- (msum
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R)) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (msum
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (d1 i j * grad i j)%R)) <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (vartlookup d3 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) neq1).
- assert (vartlookup d4 (s, DTfloat) <> None) by
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat) neq2).
- specialize (IHdf2 (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad i j * d0)%R)
- d3 d4 H6 H7 H0).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift.
- assert (df_eval_backprop_deriv σ df2 d3
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad i j * d0)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad i j * d0)%R)
- <> None) by
- (apply backprop_deriv_fully_closed_not_none; trivial).
- match_option; [|tauto].
- f_equal.
- unfold lift in IHdf1.
- rewrite eqq2, eqq3 in IHdf1; inversion IHdf1.
- rewrite eqq6, eqq7 in IHdf2; inversion IHdf2.
- rewrite (split_subvar d3 d7 d d5); trivial.
- rewrite (split_subvar d4 d8 d2 d6); trivial.
- now rewrite H11, H12.
- - Case "VectorApply"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- match_option.
- specialize (IHdf1 v0 grad_env1 grad_env2 neq1 neq2 H0).
- now rewrite eqq, H2 in IHdf1.
- - Case "MatrixApply"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- match_option.
- specialize (IHdf1 m0 grad_env1 grad_env2 neq1 neq2 H0).
- now rewrite eqq, H2 in IHdf1.
- - Case "VLossfun"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- match_option.
- specialize (IHdf1 v grad_env1 grad_env2 neq1 neq2 H0).
- now rewrite eqq, H2 in IHdf1.
- - Case "MLossfun"%string.
- destruct H.
- match_option; [|tauto].
- assert (df_eval σ df2 <> None) by
- (apply eval_fully_closed_not_none;trivial).
- match_option; [|tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros|tauto].
- match_option.
- specialize (IHdf1 m0 grad_env1 grad_env2 neq1 neq2 H0).
- now rewrite eqq, H2 in IHdf1.
- Qed.
-
- Lemma backprop_exchange_order {T} (σ:df_env) (df1 df2 :DefinedFunction UnitAnn T) (s: SubVar)
- (env:df_env) (grad1 grad2 : definition_function_types_interp T) :
- let v := (s, DTfloat) in
- vartlookup env v <> None ->
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df1 vl ->
- fully_closed_over df2 vl ->
- match
- df_eval_backprop_deriv σ df1 env grad1, df_eval_backprop_deriv σ df2 env grad2, vartlookup env v with
- | Some env1, Some env2, Some oval =>
- lift (fun e => subvar v e oval)
- (df_eval_backprop_deriv σ df1 env2 grad1) =
- lift (fun e => subvar v e oval)
- (df_eval_backprop_deriv σ df2 env1 grad2)
- | _, _, _ => True
- end.
- Proof.
- intros.
- do 3 match_option.
- unfold lift.
- do 2 match_option.
- - assert (vartlookup d0 v <> None); simpler2.
- assert (vartlookup d v <> None); simpler2.
- case_eq (vartlookup d0 v); [intros|tauto].
- case_eq (vartlookup d v); [intros|tauto].
- f_equal. subst v.
- rewrite (split_subvar d0 d2 d1 d4); trivial.
- rewrite (split_subvar d d3 d1 d5); trivial.
- generalize (backprop_indep_env σ df1 s env d0 grad1); intros.
- generalize (backprop_indep_env σ df2 s env d grad2); intros.
- simpl in H6; simpl in H7.
- cut_to H6; trivial; try congruence.
- cut_to H7; trivial; try congruence.
- unfold df_eval_backprop_delta in *.
- rewrite eqq1, H4 in H6.
- rewrite eqq1, H5 in H7.
- unfold lift in H6; simpl in H6.
- unfold lift in H7; simpl in H7.
- rewrite eqq, eqq2 in H6.
- rewrite eqq0, eqq3 in H7.
- inversion H6; inversion H7.
- rewrite H9, H10; lra.
- - assert (df_eval_backprop_deriv σ df2 d grad2 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- - assert (df_eval_backprop_deriv σ df1 d0 grad1 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- Qed.
-
- Lemma list_env_iter_id {A} (env : df_env) (l : list A) :
- list_env_iter (fun (_ : A) (env : df_env) => Some env)
- (Some env) l = Some env.
- Proof.
- now induction l.
- Qed.
-
- Lemma backprop_grad_sum_list_env_iter {m} (σ:df_env)
- (vecdf:Vector (DefinedFunction UnitAnn DTfloat) m) (s: SubVar)
- (grad_env1 grad_env2 grad_env3:df_env)
- (grad1 grad2 : (Vector float m))
- (val1 val2 val3 : float)
- (l : list {m' | m' < m} )
- :
- let v := (s, DTfloat) in
- (forall (i:{m' | m' < m}),
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over (vecdf i) vl) ->
- (forall (i:{m' | m' < m}) (env1 env2 env3 : df_env),
- vartlookup env1 v <> None ->
- vartlookup env2 v <> None ->
- vartlookup env3 v <> None ->
-
- df_eval_backprop_delta σ (vecdf i) v env3
- (grad1 i + grad2 i)%R =
- lift2 dfti_gen_plus
- (df_eval_backprop_delta σ (vecdf i) v env1 (grad1 i))
- (df_eval_backprop_delta σ (vecdf i) v env2 (grad2 i))) ->
-
- vartlookup grad_env1 v = Some val1 ->
- vartlookup grad_env2 v = Some val2 ->
- vartlookup grad_env3 v = Some val3 ->
-
- lift (fun e : df_env => subvar v e val3)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad1 j + grad2 j)%R)
- (Some grad_env3) l) =
- lift2 dfti_gen_plus
- (lift (fun e : df_env => subvar v e val1)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some grad_env1) l))
- (lift (fun e : df_env => subvar v e val2)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some grad_env2) l)).
- Proof.
- intros.
- revert val1 val2 val3 grad_env1 grad_env2 grad_env3 H1 H2 H3.
- induction l.
- - intros.
- simpl; f_equal.
- unfold subvar; simpl.
- rewrite H1,H2,H3.
- lra.
- - intros.
- simpl.
- unfold df_eval_backprop_delta in H0.
- unfold lift, lift2.
- assert (df_eval_backprop_deriv σ (vecdf a) grad_env3 (grad1 a + grad2 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env3 (grad1 a + grad2 a)%R)
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ (vecdf a) grad_env1 (grad1 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env1 (grad1 a)%R)
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ (vecdf a) grad_env2 (grad2 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ (vecdf a) grad_env2 (grad2 a)%R)
- ; [intros | tauto].
-
- assert (vartlookup d v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 v); congruence.
- assert (vartlookup d0 v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 v); congruence.
- assert (vartlookup d1 v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 v); congruence.
-
- case_eq (vartlookup d v); [intros v3 eq3 |tauto].
- case_eq (vartlookup d0 v); [intros v1 eq1 |tauto].
- case_eq (vartlookup d1 v); [intros v2 eq2 |tauto].
-
- specialize (IHl v1 v2 v3 d0 d1 d eq1 eq2 eq3).
- unfold lift, lift2.
- match_option.
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some d0) l).
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some d1) l).
- + intros.
- rewrite eqq, H13, H14 in IHl.
- unfold lift, lift2 in IHl; simpl in IHl.
- f_equal.
- subst v.
- rewrite (split_subvar d d2 val3 v3); trivial.
- rewrite (split_subvar d0 d4 val1 v1); trivial.
- rewrite (split_subvar d1 d3 val2 v2); trivial.
- inversion IHl.
- rewrite H16.
- specialize (H0 a grad_env1 grad_env2 grad_env3).
- cut_to H0; try congruence.
- rewrite H1,H2,H3 in H0.
- rewrite H5,H7,H9 in H0.
- unfold lift, lift2 in H0.
- inversion H0.
- rewrite H17; lra.
- + intros.
- assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad2 j)) (Some d1) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- + intros.
- assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad1 j)) (Some d0) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- + assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env (grad1 j + grad2 j)%R) (Some d) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- Qed.
-
- Lemma backprop_mat_grad_sum_list_env_iter {m n} (σ:df_env)
- (df : DefinedFunction UnitAnn (DTVector n)) (s: SubVar)
- (grad_env1 grad_env2 grad_env3:df_env)
- (grad1 grad2 : (Matrix float m n))
- (val1 val2 val3 : float)
- (l : list {m' | m' < m} )
- :
- let v := (s, DTfloat) in
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- (forall (i:{m' | m' < m}) (env1 env2 env3 : df_env),
- vartlookup env1 v <> None ->
- vartlookup env2 v <> None ->
- vartlookup env3 v <> None ->
-
- df_eval_backprop_delta σ df v env3 (fun j => (grad1 i j + grad2 i j)%R) =
- lift2 dfti_gen_plus
- (df_eval_backprop_delta σ df v env1 (grad1 i))
- (df_eval_backprop_delta σ df v env2 (grad2 i))) ->
-
- vartlookup grad_env1 v = Some val1 ->
- vartlookup grad_env2 v = Some val2 ->
- vartlookup grad_env3 v = Some val3 ->
-
- lift (fun e : df_env => subvar v e val3)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (fun k => (grad1 j k + grad2 j k)%R) )
- (Some grad_env3) l) =
- lift2 dfti_gen_plus
- (lift (fun e : df_env => subvar v e val1)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad1 j)) (Some grad_env1) l))
- (lift (fun e : df_env => subvar v e val2)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad2 j)) (Some grad_env2) l)).
- Proof.
- intros.
- revert val1 val2 val3 grad_env1 grad_env2 grad_env3 H1 H2 H3.
- induction l.
- - intros.
- simpl; f_equal.
- unfold subvar; simpl.
- rewrite H1,H2,H3.
- lra.
- - intros.
- unfold df_eval_backprop_delta in *.
- simpl.
- unfold lift, lift2.
- assert (df_eval_backprop_deriv σ df grad_env3
- (fun k : {n' : nat | n' < n} => (grad1 a k + grad2 a k)%R) <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ df grad_env3
- (fun k : {n' : nat | n' < n} => (grad1 a k + grad2 a k)%R))
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ df grad_env1 (grad1 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ df grad_env1 (grad1 a)%R)
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ df grad_env2 (grad2 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- case_eq (df_eval_backprop_deriv σ df grad_env2 (grad2 a)%R)
- ; [intros | tauto].
-
- assert (vartlookup d v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H5 v); congruence.
- assert (vartlookup d0 v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 v); congruence.
- assert (vartlookup d1 v <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 v); congruence.
-
- case_eq (vartlookup d v); [intros v3 eq3 |tauto].
- case_eq (vartlookup d0 v); [intros v1 eq1 |tauto].
- case_eq (vartlookup d1 v); [intros v2 eq2 |tauto].
-
- specialize (IHl v1 v2 v3 d0 d1 d eq1 eq2 eq3).
- match_option.
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad1 j)) (Some d0) l).
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad2 j)) (Some d1) l).
- + intros.
- rewrite eqq, H13, H14 in IHl.
- unfold lift, lift2 in IHl; simpl in IHl.
- f_equal.
- subst v.
- rewrite (split_subvar d d2 val3 v3); trivial.
- rewrite (split_subvar d0 d4 val1 v1); trivial.
- rewrite (split_subvar d1 d3 val2 v2); trivial.
- inversion IHl.
- rewrite H16.
- specialize (H0 a grad_env1 grad_env2 grad_env3).
- cut_to H0; try congruence.
- rewrite H1,H2,H3 in H0.
- rewrite H5,H7,H9 in H0.
- unfold lift, lift2 in H0.
- inversion H0.
- rewrite H17; lra.
- + intros.
- assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad2 j)) (Some d1) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- + intros.
- assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env (grad1 j)) (Some d0) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- + assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ df env
- (fun k : {n' : nat | n' < n} => (grad1 j k + grad2 j k)%R))
- (Some d) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply H.
- tauto.
- Qed.
-
- Lemma matrix_zip_m_n {T} {m n} {i j} {m1 m2 : Matrix T m n} :
- matrix_zip m1 m2 i j = (m1 i j, m2 i j).
- Proof.
- unfold matrix_zip.
- rewrite vmap_nth; simpl.
- now unfold vector_zip.
- Qed.
-
- Lemma backprop_grad_sum {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (s: SubVar)
- (grad_env1 grad_env2 grad_env3:df_env)
- (grad1 grad2 : definition_function_types_interp T) :
- let v := (s, DTfloat) in
- vartlookup grad_env1 v <> None ->
- vartlookup grad_env2 v <> None ->
- vartlookup grad_env3 v <> None ->
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- df_eval_backprop_delta σ df (s,DTfloat) grad_env3 (dfti_gen_plus grad1 grad2) =
- lift2 dfti_gen_plus
- (df_eval_backprop_delta σ df (s,DTfloat) grad_env1 grad1)
- (df_eval_backprop_delta σ df (s,DTfloat) grad_env2 grad2).
- Proof.
- unfold df_eval_backprop_delta.
- revert grad_env1 grad_env2 grad_env3.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
-(*
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case
-*)
- ; simpl; intros.
- - Case "Number"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1.
- f_equal; lra.
- - Case "Constant"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1.
- f_equal; lra.
- - Case "DVector"%string.
- unfold two_vector_env_iter_alt in *.
- rewrite vforall_forall in H3.
- revert grad_env1 grad_env2 grad_env3 H0 H1 H2.
- induction (bounded_seq0 n).
- + intros.
- simpl.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1.
- f_equal; lra.
- + intros.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift; simpl in *.
- assert (df_eval_backprop_deriv σ (x a) grad_env3 (grad1 a + grad2 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- assert (df_eval_backprop_deriv σ (x a) grad_env1 (grad1 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- assert (df_eval_backprop_deriv σ (x a) grad_env2 (grad2 a)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ (x a) grad_env3 (grad1 a + grad2 a)%R)
- ; [intros|tauto].
- case_eq (df_eval_backprop_deriv σ (x a) grad_env1 (grad1 a)%R)
- ; [intros|tauto].
- case_eq (df_eval_backprop_deriv σ (x a) grad_env2 (grad2 a)%R)
- ; [intros|tauto].
- assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad1 i + grad2 i)%R)
- (Some d2) l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply (H3 a0).
- match_option; [|tauto].
- assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad1 i)) (Some d3)
- l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply (H3 a0).
- match_option; [|tauto].
- assert (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (grad2 i)) (Some d4)
- l <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- apply (H3 a0).
- match_option; [|tauto].
- unfold lift2; simpl.
-
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 (s, DTfloat) H2).
- assert (vartlookup d3 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H8 (s, DTfloat) H0).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 (s, DTfloat) H1).
-
- case_eq (vartlookup d2 (s, DTfloat)); [intros|tauto].
- case_eq (vartlookup d3 (s, DTfloat)); [intros|tauto].
- case_eq (vartlookup d4 (s, DTfloat)); [intros|tauto].
-
- rewrite (split_subvar d2 d5 d d8); trivial.
- rewrite (split_subvar d3 d6 d0 d9); trivial.
- rewrite (split_subvar d4 d7 d1 d10); trivial.
-
- f_equal.
-
- specialize (IHl d3 d4 d2 H14 H15 H13).
- rewrite H16, H17, H18 in IHl.
- rewrite eqq2, eqq3, eqq4 in IHl.
- unfold lift, lift2 in IHl.
- inversion IHl.
- rewrite H20.
-
- specialize (H a (grad1 a) (grad2 a) grad_env1 grad_env2 grad_env3 H0 H1 H2).
- specialize (H (H3 a)).
-
- rewrite eqq,eqq0,eqq1 in H.
- rewrite H7, H8, H9 in H.
- unfold lift, lift2 in H.
- inversion H.
- rewrite H21.
- lra.
- - Case "DMatrix"%string.
- unfold two_matrix_env_iter_alt in *.
- rewrite vforall_forall in H3.
- revert grad_env1 grad_env2 grad_env3 H0 H1 H2.
- induction (bounded_seq0 n); induction (bounded_seq0 m).
- + intros.
- simpl.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1.
- f_equal; lra.
- + intros.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- simpl.
- unfold lift, lift2; simpl; f_equal.
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1; simpl; lra.
- + intros.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- rewrite list_env_iter_id.
- rewrite list_env_iter_id.
- rewrite list_env_iter_id.
- unfold subvar; simpl.
- rewrite eqq, eqq0, eqq1; simpl; f_equal; lra.
- + intros.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- simpl.
- assert (df_eval_backprop_deriv σ (x a a0) grad_env3 (grad1 a a0 + grad2 a a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- assert (df_eval_backprop_deriv σ (x a a0) grad_env1 (grad1 a a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- assert (df_eval_backprop_deriv σ (x a a0) grad_env2 (grad2 a a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- case_eq (df_eval_backprop_deriv σ (x a a0) grad_env3 (grad1 a a0 + grad2 a a0)%R)
- ; [intros|tauto].
- case_eq (df_eval_backprop_deriv σ (x a a0) grad_env1 (grad1 a a0)%R)
- ; [intros|tauto].
- case_eq (df_eval_backprop_deriv σ (x a a0) grad_env2 (grad2 a a0)%R)
- ; [intros|tauto].
- assert (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R)
- (Some d2) l0 <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R)
- (Some d2) l0 ); [intros | tauto].
- assert (list_env_iter
- (fun (i : {n' : nat | n' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a i) env (grad1 a i)) (Some d3)
- l0 <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad1 a j)%R)
- (Some d3) l0 ); [intros | tauto].
- assert (list_env_iter
- (fun (i : {n' : nat | n' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a i) env (grad2 a i)) (Some d4)
- l0 <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a); rewrite vforall_forall in H3; apply H3.
- case_eq (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad2 a j)%R)
- (Some d4) l0 ); [intros | tauto].
- assert
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad1 i j + grad2 i j)%R)
- (df_eval_backprop_deriv σ (x i a0) env (grad1 i a0 + grad2 i a0)%R) l0)
- (Some d5) l <> None).
- apply list_env_iter_total_fun; intros.
- assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0 + grad2 a1 a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0 + grad2 a1 a0)%R); [intros|tauto].
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- unfold lift; simpl.
- match_option; [|tauto].
-
- assert
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad1 i j)%R)
- (df_eval_backprop_deriv σ (x i a0) env (grad1 i a0)%R) l0)
- (Some d6) l <> None).
- apply list_env_iter_total_fun; intros.
- assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad1 a1 a0)%R); [intros|tauto].
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- match_option; [|tauto].
-
- assert
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (grad2 i j)%R)
- (df_eval_backprop_deriv σ (x i a0) env (grad2 i a0)%R) l0)
- (Some d7) l <> None).
- apply list_env_iter_total_fun; intros.
- assert (df_eval_backprop_deriv σ (x a1 a0) env0 (grad2 a1 a0)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- case_eq (df_eval_backprop_deriv σ (x a1 a0) env0 (grad2 a1 a0)%R); [intros|tauto].
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (H3 a1); rewrite vforall_forall in H3; apply H3.
- match_option; [|tauto].
-
- unfold lift2; simpl; f_equal.
-
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H7 (s, DTfloat) H2).
- assert (vartlookup d3 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H8 (s, DTfloat) H0).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H9 (s, DTfloat) H1).
-
- case_eq (vartlookup d2 (s, DTfloat)); [intros|tauto].
- case_eq (vartlookup d3 (s, DTfloat)); [intros|tauto].
- case_eq (vartlookup d4 (s, DTfloat)); [intros|tauto].
-
- assert (Hc := H).
- assert (H3c := H3).
-
- specialize (H a a0 (grad1 a a0) (grad2 a a0) grad_env1 grad_env2 grad_env3
- H0 H1 H2).
-
- specialize (H3 a).
- rewrite vforall_forall in H3.
- specialize (H (H3 a0)).
- rewrite eqq, eqq0, eqq1 in H; simpl in H.
- rewrite H7, H8, H9 in H; unfold lift, lift2 in H; simpl in H.
-
- assert (vartlookup d5 (s, DTfloat) <> None).
- apply (vartlookup_list_env_iter
- s (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad1 a j + grad2 a j)%R)
- l0 d2 d5); trivial; intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H25 (s, DTfloat) H26).
- case_eq (vartlookup d5 (s, DTfloat)); [intros|tauto].
- assert (vartlookup d6 (s, DTfloat) <> None).
- apply (vartlookup_list_env_iter
- s (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad1 a j)%R)
- l0 d3 d6); trivial; intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H27 (s, DTfloat) H28).
- case_eq (vartlookup d6 (s, DTfloat)); [intros|tauto].
- assert (vartlookup d7 (s, DTfloat) <> None).
- apply (vartlookup_list_env_iter
- s (fun (j : {m' : nat | m' < m}) (env : df_env) =>
- df_eval_backprop_deriv σ (x a j) env (grad2 a j)%R)
- l0 d4 d7); trivial; intros.
- apply (df_eval_backprop_deriv_preserves_lookup_not_none H29 (s, DTfloat) H30).
- case_eq (vartlookup d7 (s, DTfloat)); [intros|tauto].
-
- specialize (IHl d6 d7 d5 H27 H29 H25).
- rewrite H28, H30, H26 in IHl; simpl in IHl.
-
- rewrite eqq2,eqq3,eqq4 in IHl.
- unfold lift, lift2 in IHl; simpl in IHl.
-
- rewrite (split_subvar d5 d8 d d14); trivial.
- rewrite (split_subvar d6 d9 d0 d15); trivial.
- rewrite (split_subvar d7 d10 d1 d16); trivial.
-
- rewrite (split_subvar d2 d5 d d11); trivial.
- rewrite (split_subvar d3 d6 d0 d12); trivial.
- rewrite (split_subvar d4 d7 d1 d13); trivial.
-
- inversion H.
- inversion IHl.
- rewrite H32, H33.
-
- generalize (backprop_grad_sum_list_env_iter
- σ (x a) s d3 d4 d2 (grad1 a) (grad2 a)
- d12 d13 d11 l0); intros.
- specialize (H31 H3).
- cut_to H31.
- * rewrite H11,H13,H15 in H31.
- unfold lift, lift2 in H31.
- inversion H31.
- rewrite H35; lra.
- * intros.
- unfold df_eval_backprop_delta.
- specialize (Hc a i (grad1 a i) (grad2 a i) env1 env2 env3).
- specialize (Hc H34 H35 H36).
- specialize (Hc (H3 i)).
- apply Hc.
- * trivial.
- * trivial.
- * trivial.
- - Case "Var"%string.
- match_option; [|tauto].
- case_eq (vartlookup grad_env1 (s, DTfloat)); [intros| tauto].
- case_eq (vartlookup grad_env2 (s, DTfloat)); [intros| tauto].
- destruct (v == (s,DTfloat)).
- + invcs e.
- rewrite eqq, H3, H4; simpl.
- f_equal.
- rewrite subvar_addvar_scalar_eq; trivial.
- rewrite subvar_addvar_scalar_eq; trivial.
- rewrite subvar_addvar_scalar_eq; trivial.
- + assert (v<> (s, DTfloat)) by congruence.
- match_option.
- * match_option.
- -- match_option; unfold lift, lift2; simpl; f_equal.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- rewrite subvar_addvar_scalar_neq; trivial.
- rewrite subvar_addvar_scalar_neq; trivial.
- lra.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl; rewrite H4; lra.
- -- unfold lift, lift2; simpl; f_equal.
- rewrite subvar_addvar_scalar_neq; trivial.
- case_eq (vartlookup grad_env2 v); intros.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- f_equal; unfold subvar; simpl.
- rewrite H3; lra.
- ++ unfold subvar; simpl.
- rewrite H3, H4.
- f_equal; lra.
- * match_option.
- -- match_option; unfold lift, lift2; simpl; f_equal.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl.
- rewrite eqq; lra.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl.
- rewrite eqq, H4; lra.
- -- match_option; unfold lift, lift2; simpl; f_equal.
- ++ rewrite subvar_addvar_scalar_neq; trivial.
- unfold subvar; simpl.
- rewrite eqq, H3; lra.
- ++ unfold subvar; simpl.
- rewrite eqq, H3, H4; lra.
- - Case "Plus"%string.
- destruct H2.
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 grad1 grad2 env1 env3 env).
- simpl in IHdf1.
- rewrite eqq5, eqq6, eqq7 in IHdf1.
- rewrite eqq, eqq1, eqq3 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- rewrite eqq0, eqq2, eqq4 in IHdf2.
- unfold lift, lift2 in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 val val2); trivial.
- rewrite (split_subvar env1 env2 val0 val3); trivial.
- rewrite (split_subvar env3 env4 val1 val4); trivial.
- rewrite H5, H6; lra.
- - Case "Minus"%string.
- destruct H2.
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 (-grad1)%R (-grad2)%R env1 env3 env).
- simpl in IHdf1.
- rewrite eqq5, eqq6, eqq7 in IHdf1.
- rewrite eqq, eqq1, eqq3 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- replace (- grad1 + - grad2)%R with (- (grad1 + grad2))%R in IHdf2 by lra.
- rewrite eqq0, eqq2, eqq4 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- rewrite (split_subvar env env0 val val2); trivial.
- rewrite (split_subvar env1 env2 val0 val3); trivial.
- rewrite (split_subvar env3 env4 val1 val4); trivial.
- f_equal; rewrite H5, H6; lra.
- - Case "Times"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1 (d1 * grad1)%R (d1 * grad2)%R grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 (d0 * grad1)%R (d0 * grad2)%R env1 env3 env).
- rewrite eqq8, eqq9, eqq in IHdf1.
- simpl in IHdf1.
- replace (d1 *grad1 + d1*grad2)%R with (d1 * (grad1 + grad2))%R in IHdf1 by lra.
- rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- replace (d0 *grad1 + d0 *grad2)%R with (d0 *(grad1 + grad2))%R in IHdf2 by lra.
- rewrite eqq3,eqq5,eqq7 in IHdf2.
- unfold lift, lift2 in IHdf2; inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 d val1); trivial.
- rewrite (split_subvar env1 env2 val val2); trivial.
- rewrite (split_subvar env3 env4 val0 val3); trivial.
- rewrite H7, H8; lra.
- - Case "Divide"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1 (grad1/d1)%R (grad2/d1)%R grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 (-d0/(d1*d1) * grad1)%R (-d0/(d1*d1) * grad2)%R env1 env3 env).
- rewrite eqq, eqq8, eqq9 in IHdf1.
- simpl in IHdf1.
- replace (grad1/d1 + grad2/d1)%R with ((grad1 + grad2)/d1)%R in IHdf1 by lra.
- rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- replace (-d0/(d1*d1) *grad1 + -d0/(d1*d1) *grad2)%R with (-d0/(d1*d1) *(grad1 + grad2))%R in IHdf2 by lra.
- rewrite eqq3, eqq5, eqq7 in IHdf2.
- unfold lift, lift2 in IHdf2; inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 d val1); trivial.
- rewrite (split_subvar env1 env2 val val2); trivial.
- rewrite (split_subvar env3 env4 val0 val3); trivial.
- rewrite H7, H8; lra.
- - Case "Square"%string.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df); intros.
- specialize (H3 H2).
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (2 * d0 *grad1)%R (2 * d0 *grad2)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq1, eqq2 in IHdf.
- simpl in IHdf.
- replace (2 * d0 * grad1 + 2 * d0 * grad2)%R with (2 * d0 * (grad1 + grad2))%R in IHdf by lra.
- apply IHdf.
- - Case "Exp"%string.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df); intros.
- specialize (H3 H2).
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (grad1 * exp d0)%R (grad2*exp d0)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq1, eqq2 in IHdf.
- simpl in IHdf.
- replace (grad1 * exp d0 + grad2 * exp d0)%R with ((grad1 + grad2) * exp d0)%R in IHdf by lra.
- apply IHdf.
- - Case "Log"%string.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df); intros.
- specialize (H3 H2).
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (grad1/d0)%R (grad2/d0)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq1, eqq2 in IHdf.
- simpl in IHdf.
- replace (grad1/d0 + grad2/d0)%R with ((grad1 + grad2) / d0)%R in IHdf by lra.
- apply IHdf.
- - Case "Abs"%string.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df); intros.
- specialize (H3 H2).
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (grad1*sign d0)%R (grad2*sign d0)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq1, eqq2 in IHdf.
- simpl in IHdf.
- replace (grad1*sign d0 + grad2*sign d0)%R with ((grad1 + grad2) * sign d0)%R in IHdf by lra.
- apply IHdf.
- - Case "Sign"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (0)%R (0)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- replace (0 + 0)%R with 0%R in IHdf by lra.
- apply IHdf.
- - Case "PSign"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (0)%R (0)%R grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- replace (0 + 0)%R with 0%R in IHdf by lra.
- apply IHdf.
- - Case "Max"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- destruct (Rle_dec d0 d1).
- + match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf2 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H3).
- rewrite eqq, eqq2, eqq3 in IHdf2; simpl in IHdf2.
- apply IHdf2.
- + match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- rewrite eqq, eqq2, eqq3 in IHdf1; simpl in IHdf1.
- apply IHdf1.
- - Case "VectorDot"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1
- (vmap (fun rv : R => (rv * grad1)%R) d1)
- (vmap (fun rv : R => (rv * grad2)%R) d1)
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2
- (vmap (fun lv : R => (lv * grad1)%R) d0)
- (vmap (fun lv : R => (lv * grad2)%R) d0)
- env1 env3 env).
- rewrite eqq, eqq8, eqq9 in IHdf1.
- simpl in IHdf1.
- replace
- (fun i : {n' : nat | n' < n} =>
- (vmap (fun rv : R => rv * grad1) d1 i + vmap (fun rv : R => rv * grad2) d1 i)%R)
- with
- (vmap (fun rv : R => (rv * (grad1 + grad2))%R) d1)
- in IHdf1.
- + rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; inversion IHdf1.
- simpl in IHdf2.
- cut_to IHdf2; simpler2.
- replace (fun i : {n' : nat | n' < n} =>
- (vmap (fun lv : R => lv * grad1) d0 i +
- vmap (fun lv : R => lv * grad2) d0 i)%R) with
- (vmap (fun lv : R => (lv * (grad1 + grad2))%R) d0) in IHdf2.
- * rewrite eqq3, eqq5, eqq7 in IHdf2.
- unfold lift, lift2 in IHdf2; inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 d val1); trivial.
- rewrite (split_subvar env1 env2 val val2); trivial.
- rewrite (split_subvar env3 env4 val0 val3); trivial.
- rewrite H7, H8; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite vmap_nth.
- lra.
- + apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite vmap_nth.
- lra.
- - Case "VectorSum"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (ConstVector n grad1) (ConstVector n grad2) grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- unfold ConstVector in IHdf.
- unfold ConstVector.
- apply IHdf.
- - Case "MatrixSum"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (ConstMatrix m n grad1) (ConstMatrix m n grad2) grad_env1 grad_env2 grad_env3).
- specialize (IHdf H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- unfold ConstMatrix in IHdf.
- unfold ConstMatrix.
- apply IHdf.
- - Case "VectorElem"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf (fun k : {n' : nat | n' < n} =>
- if equiv_dec (` k) (` i) then grad1 else 0%R)
- (fun k : {n' : nat | n' < n} =>
- if equiv_dec (` k) (` i) then grad2 else 0%R)
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- replace (fun k : {n' : nat | n' < n} =>
- if equiv_dec (` k) (` i) then (grad1 + grad2)%R else 0%R) with
- (fun i0 : {n' : nat | n' < n} =>
- ((if equiv_dec (` i0) (` i) then grad1 else 0) +
- (if equiv_dec (` i0) (` i) then grad2 else 0))%R).
- apply IHdf.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)); lra.
- - Case "MatrixElem"%string.
- match_option; [|tauto].
- match_option; [|tauto].
- match_option; [|tauto].
- specialize (IHdf
- (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i)
- then if equiv_dec (` k2) (` j) then grad1 else 0%R
- else 0%R)
- (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i)
- then if equiv_dec (` k2) (` j) then grad2 else 0%R
- else 0%R)
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- rewrite eqq, eqq0, eqq1 in IHdf.
- simpl in IHdf.
- replace (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i)
- then if equiv_dec (` k2) (` j) then (grad1 + grad2)%R else 0%R
- else 0%R) with
- (fun (i0 : {n' : nat | n' < m}) (j0 : {m' : nat | m' < n}) =>
- ((if equiv_dec (` i0) (` i)
- then if equiv_dec (` j0) (` j) then grad1 else 0
- else 0) +
- (if equiv_dec (` i0) (` i)
- then if equiv_dec (` j0) (` j) then grad2 else 0
- else 0))%R).
- apply IHdf.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)).
- + apply FunctionalExtensionality.functional_extensionality; intros.
- destruct (equiv_dec (` x0) (` j)); lra.
- + apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "MatrixVectorMult"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1
- (fun (i : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (grad1 i * d1 j)%R)
- (fun (i : {n' : nat | n' < m})
- (j : {m' : nat | m' < n}) => (grad2 i * d1 j)%R)
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- rewrite eqq, eqq8, eqq9 in IHdf1.
- specialize (IHdf2
- (matrix_vector_mult
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d0 j i) grad1)
- (matrix_vector_mult
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => d0 j i) grad2)
- env1 env3 env).
- simpl in IHdf1.
- replace
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (grad1 i * d1 j + grad2 i * d1 j)%R) with
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- ((grad1 i + grad2 i) * d1 j)%R) in IHdf1.
- + rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- replace
- (fun i : {n' : nat | n' < n} =>
- (matrix_vector_mult
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d0 j i0) grad1 i +
- matrix_vector_mult
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d0 j i0) grad2 i)%R) with
- (matrix_vector_mult
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => d0 j i)
- (fun i : {n' : nat | n' < m} => (grad1 i + grad2 i)%R)) in IHdf2.
- * rewrite eqq3, eqq5, eqq7 in IHdf2.
- unfold lift, lift2 in IHdf2; inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 d val1); trivial.
- rewrite (split_subvar env1 env2 val val2); trivial.
- rewrite (split_subvar env3 env4 val0 val3); trivial.
- rewrite H7, H8; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- unfold matrix_vector_mult.
- rewrite vsum_plus; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- + apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "MatrixVectorAdd"%string.
- destruct H2.
- simpl; intros.
- repeat simpl_closed_backprop.
- simpler2.
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- simpl in IHdf1.
- rewrite eqq,eqq0,eqq1,eqq2,eqq3,eqq4 in IHdf1.
- unfold lift, lift2 in IHdf1.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H); intro.
- case_eq (vartlookup env0 (s, DTfloat)); [intros|tauto].
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat) H0); intro.
- case_eq (vartlookup env1 (s, DTfloat)); [intros|tauto].
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq (s, DTfloat) H1); intro.
- case_eq (vartlookup env (s, DTfloat)); [intros|tauto].
- generalize (backprop_mat_grad_sum_list_env_iter
- σ df2 s env0 env1 env (transpose grad1) (transpose grad2)
- d d0 d1 (bounded_seq0 n)); intros.
- simpl in H10.
- specialize (H10 H3).
- cut_to H10; trivial.
- + do 3 match_option.
- * rewrite eqq6, eqq7 in H10.
- replace (fun (j : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env
- (fun k : {n' : nat | n' < m} => ((@transpose R m n grad1 j k) +
- (@transpose R m n grad2 j k))%R))
- with
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env
- (transpose
- (fun (i0 : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (grad1 i0 j + grad2 i0 j)%R) i)) in H10.
- -- rewrite eqq5 in H10.
- unfold lift, lift2 in H10.
- unfold lift, lift2; f_equal.
- inversion IHdf1; inversion H10.
- rewrite (split_subvar env d2 val d1); trivial.
- rewrite (split_subvar env0 d3 val0 d); trivial.
- rewrite (split_subvar env1 d4 val1 d0); trivial.
- rewrite H12, H13; lra.
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- f_equal.
- * assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad2 i))
- (Some env1) (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- * assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad1 i))
- (Some env0) (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- * assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env (transpose grad1 i))
- (Some env0) (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- * assert (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ df2 env
- (transpose
- (fun (i0 : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (grad1 i0 j + grad2 i0 j)%R) i)) (Some env) (bounded_seq0 n) <> None).
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- + intros.
- apply IHdf2; trivial.
- - Case "MatrixMult"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- simpler2.
- specialize
- (IHdf1
- (matrix_mult grad1
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i))
- (matrix_mult grad2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i))
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- rewrite eqq, eqq8, eqq9 in IHdf1.
- specialize (IHdf2
- (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d0 j i)
- grad1)
- (matrix_mult (fun (i : {n' : nat | n' < p})
- (j : {m' : nat | m' < m}) => d0 j i)
- grad2)
- env1 env3 env).
- simpl in IHdf1.
- replace
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < p}) =>
- (matrix_mult grad1
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) =>
- d1 j0 i0) i j +
- matrix_mult grad2
- (fun (i0 : {n' : nat | (n' < n)%nat}) (j0 : {m' : nat | (m' < p)%nat}) =>
- d1 j0 i0) i j)%R)
- with
- (matrix_mult
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (grad1 i j + grad2 i j)%R)
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < p}) => d1 j i)) in IHdf1.
- + rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- cut_to IHdf2; simpler2.
- simpl in IHdf2.
- rewrite eqq5, eqq7 in IHdf2.
- replace
- (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < n}) =>
- (matrix_mult
- (fun (i0 : {n' : nat | (n' < p)%nat}) (j0 : {m' : nat | (m' < m)%nat}) =>
- d0 j0 i0) grad1 i j +
- matrix_mult
- (fun (i0 : {n' : nat | (n' < p)%nat}) (j0 : {m' : nat | (m' < m)%nat}) =>
- d0 j0 i0) grad2 i j)%R) with
- (matrix_mult (fun (i : {n' : nat | n' < p}) (j : {m' : nat | m' < m}) => d0 j i)
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (grad1 i j + grad2 i j)%R))
- in IHdf2.
- * rewrite eqq3 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 d val1); trivial.
- rewrite (split_subvar env1 env2 val val2); trivial.
- rewrite (split_subvar env3 env4 val0 val3); trivial.
- rewrite H7, H8; lra.
- * unfold matrix_mult.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vsum_plus; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- + unfold matrix_mult.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vsum_plus; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl; lra.
- - Case "VectorPlus"%string.
- destruct H2.
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 grad1 grad2 env1 env3 env).
- simpler2.
- simpl in IHdf1.
- rewrite eqq1, eqq3, eqq in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- cut_to IHdf2; try congruence.
- simpl in IHdf2.
- rewrite eqq0, eqq2, eqq4 in IHdf2.
- unfold lift, lift2 in IHdf2; inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 val val5); trivial.
- rewrite (split_subvar env1 env2 val0 val6); trivial.
- rewrite (split_subvar env3 env4 val1 val7); trivial.
- rewrite eqq5 in eqq8.
- rewrite eqq6 in eqq9.
- rewrite eqq7 in eqq10.
- invcs eqq8; invcs eqq9; invcs eqq10.
- rewrite H5, H6; lra.
- - Case "VectorMinus"%string.
- destruct H2.
- unfold lift, lift2.
- repeat simpl_closed_backprop.
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 (fun i : {n' : nat | n' < n} => (- grad1 i)%R)
- (fun i : {n' : nat | n' < n} => (- grad2 i)%R)
- env1 env3 env).
- simpler2.
- simpl in IHdf1.
- rewrite eqq, eqq1, eqq3 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- cut_to IHdf2; try congruence.
- simpl in IHdf2.
- replace (fun i : {n' : nat | n' < n} => (- grad1 i + - grad2 i)%R) with
- (fun i : {n' : nat | n' < n} => (- (grad1 i + grad2 i))%R) in IHdf2.
- rewrite eqq0, eqq2, eqq4 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar env env0 val val5); trivial.
- rewrite (split_subvar env1 env2 val0 val6); trivial.
- rewrite (split_subvar env3 env4 val1 val7); trivial.
- rewrite eqq5 in eqq8; rewrite eqq6 in eqq9; rewrite eqq7 in eqq10.
- invcs eqq8; invcs eqq9; invcs eqq10.
- rewrite H5, H6; lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "MatrixPlus"%string.
- destruct H2.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv
- σ df1 grad_env3
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 (grad1)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 (grad2)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2 grad1 grad2 d2 d4 d0).
- rewrite eqq, eqq1, eqq3 in IHdf1.
- simpl in IHdf1.
- rewrite eqq0, eqq2, eqq4 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H0).
- assert (vartlookup d0 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H1).
- specialize (IHdf2 H7 H9 H10 H3).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- simpl in IHdf2.
- unfold lift, lift2; simpl.
- assert (df_eval_backprop_deriv
- σ df2 d0
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R) <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d2 (grad1)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d2 (grad1)%R); [intros|tauto].
- assert (df_eval_backprop_deriv σ df2 d4 (grad2)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d4 (grad2)%R); [intros|tauto].
- simpl in IHdf2.
- rewrite eqq8, H13, H15 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar d0 d8 d d5); trivial.
- rewrite (split_subvar d2 d9 d1 d6); trivial.
- rewrite (split_subvar d4 d10 d3 d7); trivial.
- rewrite H8, H17; lra.
- - Case "MatrixMinus"%string.
- destruct H2.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv
- σ df1 grad_env3
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad1 i j + grad2 i j)%R) <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1 (grad1)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2 (grad2)%R <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- specialize (IHdf1 grad1 grad2 grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (- grad1 i j)%R)
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (- grad2 i j)%R)
- d2 d4 d0).
- rewrite eqq, eqq1, eqq3 in IHdf1.
- simpl in IHdf1.
- rewrite eqq0, eqq2, eqq4 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H0).
- assert (vartlookup d0 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq0 (s, DTfloat) H1).
- specialize (IHdf2 H7 H9 H10 H3).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift, lift2; simpl.
- assert (df_eval_backprop_deriv
- σ df2 d0
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (- (grad1 i j + grad2 i j))%R) <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv
- σ df2 d2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad1 i j)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad1 i j)%R))
- ; [intros |tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad2 i j)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d4
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) => (- grad2 i j)%R))
- ; [intros | tauto].
- simpl in IHdf2.
- replace
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (- grad1 i j + - grad2 i j)%R) with
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (- (grad1 i j + grad2 i j))%R) in IHdf2.
- rewrite eqq8, H13, H15 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar d0 d8 d d5); trivial.
- rewrite (split_subvar d2 d9 d1 d6); trivial.
- rewrite (split_subvar d4 d10 d3 d7); trivial.
- rewrite H8, H17; lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "VectorScalMult"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env3
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * (grad1 j + grad2 j))%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad1 j)%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad2 j)%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- specialize (IHdf1
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad1 j)%R))
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * grad2 j)%R))
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2
- (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R)
- (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R)
- d4 d6 d2).
- rewrite eqq, eqq3, eqq5 in IHdf1.
- simpl in IHdf1.
- replace
- (@vsum floatish_R n (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad1 j) +
- @vsum floatish_R n (fun j : {n' : nat | (n' < n)%nat} => d1 j * grad2 j))%R with
- (vsum (fun j : {n' : nat | n' < n} => (d1 j * (grad1 j + grad2 j))%R)) in IHdf1.
- + rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H1).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H).
- assert (vartlookup d6 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq6 (s, DTfloat) H0).
- specialize (IHdf2 H11 H12 H9 H3).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift, lift2; simpl.
- assert (df_eval_backprop_deriv
- σ df2 d2
- (fun j : {n' : nat | n' < n} => (d0 * (grad1 j + grad2 j))%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d4
- (fun j : {n' : nat | n' < n} => (d0 * grad1 j)%R))
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ df2 d6
- (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R)
- <> None ).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d6
- (fun j : {n' : nat | n' < n} => (d0 * grad2 j)%R))
- ; [intros | tauto].
- simpl in IHdf2.
- replace
- (fun i : {n' : nat | n' < n} => (d0 * grad1 i + d0 * grad2 i)%R) with
- (fun j : {n' : nat | n' < n} => (d0 * (grad1 j + grad2 j))%R) in IHdf2.
- * rewrite eqq10, H15, H17 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar d2 d10 d d7); trivial.
- rewrite (split_subvar d4 d11 d3 d8); trivial.
- rewrite (split_subvar d6 d12 d5 d9); trivial.
- rewrite H10, H19; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + simpl.
- rewrite vsum_plus.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "MatrixScalMult"%string.
- destruct H2.
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df1); intros.
- specialize (H4 H2).
- match_option; [|tauto].
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H5 H3).
- match_option; [|tauto].
- assert (df_eval_backprop_deriv
- σ df1 grad_env3
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * (grad1 i j + grad2 i j))%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env1
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * grad1 i j)%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df1 grad_env2
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * grad2 i j)%R))
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H2.
- match_option; [|tauto].
- specialize (IHdf1
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * grad1 i j)%R))
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * grad2 i j)%R))
- grad_env1 grad_env2 grad_env3 H H0 H1 H2).
- specialize (IHdf2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (grad1 i j * d0)%R)
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (grad2 i j * d0)%R)
- d4 d6 d2).
- rewrite eqq, eqq3, eqq5 in IHdf1.
- simpl in IHdf1.
- replace
- (@msum floatish_R n m
- (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d1 i j * grad1 i j) +
- @msum floatish_R n m
- (fun (i : {n' : nat | (n' < n)%nat}) (j : {m' : nat | (m' < m)%nat}) =>
- d1 i j * grad2 i j))%R with
- (msum
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d1 i j * (grad1 i j + grad2 i j))%R)) in IHdf1.
- + rewrite eqq2, eqq4, eqq6 in IHdf1.
- unfold lift, lift2 in IHdf1; simpl in IHdf1.
- inversion IHdf1.
- assert (vartlookup d2 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat) H1).
- assert (vartlookup d4 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s, DTfloat) H).
- assert (vartlookup d6 (s, DTfloat) <> None).
- apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq6 (s, DTfloat) H0).
- specialize (IHdf2 H11 H12 H9 H3).
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- match_option_in IHdf2; [|tauto].
- unfold lift, lift2; simpl.
- assert (df_eval_backprop_deriv
- σ df2 d2
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- ((grad1 i j + grad2 i j) * d0)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- match_option; [|tauto].
- assert (df_eval_backprop_deriv σ df2 d4
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad1 i j * d0)%R)
- <> None).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d4
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad1 i j * d0)%R))
- ; [intros | tauto].
- assert (df_eval_backprop_deriv σ df2 d6
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad2 i j * d0)%R)
- <> None ).
- apply backprop_deriv_fully_closed_not_none.
- apply H3.
- case_eq (df_eval_backprop_deriv σ df2 d6
- (fun (i : {n' : nat | n' < n})
- (j : {m' : nat | m' < m}) => (grad2 i j * d0)%R))
- ; [intros | tauto].
- simpl in IHdf2.
- replace
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (grad1 i j * d0 + grad2 i j * d0)%R) with
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- ((grad1 i j + grad2 i j) * d0)%R) in IHdf2.
- * rewrite eqq10, H15, H17 in IHdf2.
- unfold lift, lift2 in IHdf2; simpl in IHdf2.
- inversion IHdf2.
- f_equal.
- rewrite (split_subvar d2 d10 d d7); trivial.
- rewrite (split_subvar d4 d11 d3 d8); trivial.
- rewrite (split_subvar d6 d12 d5 d9); trivial.
- rewrite H10, H19; lra.
- * apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- + unfold msum.
- rewrite vsum_plus.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite vsum_plus.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- - Case "VectorApply"%string.
- destruct H2.
- simpler2.
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H4 H3).
- match_option.
- match_option.
- + match_option.
- * match_option.
- -- unfold lift.
- repeat simpl_closed_backprop.
- unfold lift2.
- specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros.
- specialize (IHdf2 v1 v2 grad_env1 grad_env2 grad_env3).
- cut_to IHdf2; try congruence.
- rewrite eqq, eqq0, eqq1, eqq7, eqq8 in IHdf2; simpl in IHdf2.
- replace (fun i : {n' : nat | n' < n} => (v1 i + v2 i)%R) with v0 in IHdf2.
- rewrite eqq6 in IHdf2.
- unfold lift, lift2 in IHdf2.
- apply IHdf2.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H5 x);rewrite vmap_nth in H5; simpl in H5.
- specialize (H6 x);rewrite vmap_nth in H6; simpl in H6.
- specialize (H7 x);rewrite vmap_nth in H7; simpl in H7.
- match_option_in H5; invcs H5.
- match_option_in H6; invcs H6.
- match_option_in H7; invcs H7.
- rewrite eqq9 in eqq10; invcs eqq10.
- rewrite eqq9 in eqq11; invcs eqq11.
- lra.
- -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- * specialize (apply vectoro_to_ovector_exists_None eqq4); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- + specialize (apply vectoro_to_ovector_exists_None eqq3); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- - Case "MatrixApply"%string.
- destruct H2.
- simpler2.
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H4 H3).
- match_option.
- match_option.
- + match_option.
- * match_option.
- -- unfold lift.
- repeat simpl_closed_backprop.
- unfold lift2.
- unfold matrixo_to_omatrix in *.
- specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros.
- specialize (IHdf2 m1 m2 grad_env1 grad_env2 grad_env3).
- cut_to IHdf2; try congruence.
- rewrite eqq, eqq0, eqq1, eqq7, eqq8 in IHdf2; simpl in IHdf2.
- replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (m1 i j + m2 i j)%R) with m0 in IHdf2.
- rewrite eqq6 in IHdf2.
- unfold lift, lift2 in IHdf2.
- apply IHdf2.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H5 x); specialize (H6 x); specialize (H7 x).
- unfold mmap in H5; unfold mmap in H6; unfold mmap in H7.
- specialize (apply vectoro_to_ovector_forall_some_f H5);intros.
- specialize (apply vectoro_to_ovector_forall_some_f H6);intros.
- specialize (apply vectoro_to_ovector_forall_some_f H7);intros.
- specialize (H8 x0); do 2 rewrite vmap_nth in H8.
- specialize (H9 x0); do 2 rewrite vmap_nth in H9.
- specialize (H10 x0); do 2 rewrite vmap_nth in H10.
- rewrite matrix_zip_m_n in H8.
- rewrite matrix_zip_m_n in H9.
- rewrite matrix_zip_m_n in H10.
- match_option_in H8.
- rewrite eqq9 in H9; rewrite eqq9 in H10.
- invcs H8; invcs H9; invcs H10.
- lra.
- -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros; destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5.
- unfold mmap in e0.
- do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- * specialize (apply vectoro_to_ovector_exists_None eqq4); intros; destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5.
- unfold mmap in e0.
- do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- + specialize (apply vectoro_to_ovector_exists_None eqq3); intros; destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros; destruct H5.
- unfold mmap in e0.
- do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x x0)] (df_deriv df1 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- - Case "VLossfun"%string.
- destruct H2.
- simpler2.
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H4 H3).
- match_option.
- match_option.
- + match_option.
- * match_option.
- -- unfold lift.
- repeat simpl_closed_backprop.
- unfold lift2.
- specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros.
- specialize (IHdf2 v0 v3 grad_env1 grad_env2 grad_env3).
- cut_to IHdf2; try congruence.
- rewrite eqq, eqq0, eqq1 in IHdf2.
- rewrite eqq7, eqq8 in IHdf2.
- simpl in IHdf2.
- replace (fun i : {n' : nat | n' < n} => (v0 i + v3 i)%R) with v in IHdf2.
- rewrite eqq6 in IHdf2.
- unfold lift, lift2 in IHdf2.
- apply IHdf2.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H5 x);rewrite vmap_nth in H5; simpl in H5.
- specialize (H6 x);rewrite vmap_nth in H6; simpl in H6.
- specialize (H7 x);rewrite vmap_nth in H7; simpl in H7.
- match_option_in H5.
- rewrite eqq9 in H6; rewrite eqq9 in H7.
- invcs H5; invcs H6; invcs H7.
- lra.
- -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert ( df_eval [mk_env_entry (v1, DTfloat) (d x);
- mk_env_entry (v2, DTfloat) (r x)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- * specialize (apply vectoro_to_ovector_exists_None eqq4); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x);
- mk_env_entry (v2, DTfloat) (r x)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- + specialize (apply vectoro_to_ovector_exists_None eqq3); intros.
- destruct H5.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x); mk_env_entry (v2, DTfloat) (r x)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- - Case "MLossfun"%string.
- destruct H2.
- simpler2.
- generalize (eval_fully_closed_not_none σ df2); intros.
- specialize (H4 H3).
- match_option.
- match_option.
- + match_option.
- * match_option.
- -- unfold lift.
- repeat simpl_closed_backprop.
- unfold lift2.
- specialize (apply vectoro_to_ovector_forall_some_f eqq3);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq4);intros.
- specialize (apply vectoro_to_ovector_forall_some_f eqq5);intros.
- specialize (IHdf2 m1 m2 grad_env1 grad_env2 grad_env3).
- cut_to IHdf2; try congruence.
- rewrite eqq, eqq0, eqq1 in IHdf2.
- rewrite eqq7, eqq8 in IHdf2.
- simpl in IHdf2.
- replace (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (m1 i j + m2 i j)%R) with m0 in IHdf2.
- rewrite eqq6 in IHdf2.
- unfold lift, lift2 in IHdf2.
- apply IHdf2.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H5 x); simpl in H5.
- specialize (H6 x); simpl in H6.
- specialize (H7 x); simpl in H7.
- specialize (apply vectoro_to_ovector_forall_some_f H5);intros.
- specialize (apply vectoro_to_ovector_forall_some_f H6);intros.
- specialize (apply vectoro_to_ovector_forall_some_f H7);intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H8 x0); unfold mmap in H8;do 2 rewrite vmap_nth in H8; simpl in H8.
- specialize (H9 x0); unfold mmap in H9;do 2 rewrite vmap_nth in H9; simpl in H9.
- specialize (H10 x0); unfold mmap in H10;do 2 rewrite vmap_nth in H10; simpl in H10.
- rewrite matrix_zip_m_n in H8.
- rewrite matrix_zip_m_n in H9.
- rewrite matrix_zip_m_n in H10.
- match_option_in H8.
- rewrite eqq9 in H9; rewrite eqq9 in H10.
- invcs H8; invcs H9; invcs H10.
- lra.
- -- specialize (apply vectoro_to_ovector_exists_None eqq5); intros.
- destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H5.
- unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert ( df_eval [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- * specialize (apply vectoro_to_ovector_exists_None eqq4); intros.
- destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H5.
- unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- + specialize (apply vectoro_to_ovector_exists_None eqq3); intros.
- destruct H5.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H5.
- unfold mmap in e0; do 2 rewrite vmap_nth in e0; simpl in e0.
- rewrite matrix_zip_m_n in e0.
- match_option_in e0.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)]
- (df_deriv df1 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- now apply fully_closed_deriv.
- tauto.
- Qed.
-
- Lemma vectoro_to_ovector_eNone_None {A n} {vo:Vector (option A) n} :
- {i | vo i = None} ->
- vectoro_to_ovector vo = None.
- Proof.
- intros.
- destruct H.
- now apply (vectoro_to_ovector_None_None x).
- Qed.
-
- Definition ConstSplitVector {T} (middle theend:nat) (part1 part2:T) : (Vector T theend) :=
- fun (i: {n':nat | n' < theend}%nat) =>
- if lt_dec (proj1_sig i) middle then part1 else part2.
-
- Definition mergeVectorZero {T} {n} (middle:nat) (part1 : Vector T n) (c:T) : (Vector T n) :=
- fun (i: {n':nat | n' < n}%nat) =>
- if lt_dec (proj1_sig i) middle then (part1 i) else c.
-
- Definition scaleUnitVector {T} (n:nat) (j : {n':nat | (n' < n)%nat}) (c:T) (zero:T) : Vector T n :=
- fun i => if (proj1_sig i) == (proj1_sig j) then c else zero%R.
-
- Lemma ConstSplitVectorSzero bound n (pf:bound < n) :
- (ConstSplitVector (S bound) n 1%R 0%R) =
- dfti_gen_plus (T:=DTVector n) (ConstSplitVector bound n 1%R 0%R) (UnitVector n (exist _ bound pf)).
- Proof.
- simpl.
- apply functional_extensionality.
- intros.
- unfold ConstSplitVector, UnitVector, equiv_dec, nat_eq_eqdec; simpl.
- destruct x as [x pff]; simpl.
- destruct (lt_dec x (S bound))
- ; destruct (lt_dec x bound)
- ; destruct (Nat.eq_dec x bound)
- ; try lia; try lra.
- Qed.
-
- Lemma mergeVectorSzero {n} (bound:nat) (pf:bound < n) (part1 : Vector R n):
- let ind := (exist _ bound pf) in
- (@mergeVectorZero (@float floatish_R) n (S bound) part1 0%R) =
- dfti_gen_plus (T:=DTVector n) (@mergeVectorZero (@float floatish_R) n bound part1 0%R)
- (scaleUnitVector n ind (part1 ind) 0%R).
- Proof.
- simpl.
- apply functional_extensionality.
- intros.
- unfold mergeVectorZero, scaleUnitVector, equiv_dec, nat_eq_eqdec; simpl.
- destruct x as [x pff]; simpl.
- destruct (lt_dec x (S bound))
- ; destruct (lt_dec x bound)
- ; destruct (Nat.eq_dec x bound)
- ; try lia; try lra.
- subst; simpl.
- ring_simplify.
- erewrite index_pf_irrel; eauto.
- Qed.
-
- Lemma mergeVectorSzero_mat {n m} (bound:nat) pf (part1 : Matrix float n m) :
- let ind := (exist _ bound pf) in
- (@mergeVectorZero (Vector float m) n (S bound) part1 (ConstVector m 0%R)) =
- dfti_gen_plus (T:=DTMatrix n m) (@mergeVectorZero (Vector float m) n bound part1 (ConstVector m 0%R))
- (scaleUnitVector (T:=Vector float m) n ind (part1 ind) (ConstVector m 0%R)).
- Proof.
- simpl.
- do 2 (apply functional_extensionality; intros).
- unfold mergeVectorZero, scaleUnitVector, ConstVector, equiv_dec, nat_eq_eqdec; simpl.
- destruct x as [x pff]; simpl.
- destruct (lt_dec x (S bound))
- ; destruct (lt_dec x bound)
- ; destruct (Nat.eq_dec x bound)
- ; simpl
- ; try lia; try lra.
- subst; simpl.
- rewrite Rplus_0_l.
- erewrite index_pf_irrel; eauto.
- Qed.
-
- Lemma vsum_alt_eq {m:nat} (v:Vector R m) : vsum v = vector_fold_right Fplus 0%R v.
- Proof.
- apply vector_fold_right1_as_vector_fold_right.
- unfold Datatypes.id; simpl; intros; lra.
- Qed.
-
- Lemma vsum_cons {m:nat} x (v:Vector R m) :
- vsum (vcons x v) = (x + vsum v)%R.
- Proof.
- repeat rewrite vsum_alt_eq.
- apply vector_fold_right_vcons.
- Qed.
-
- Lemma constSplitVectorZero {n} :
- ConstSplitVector 0 n 1%R 0%R = ConstVector n 0%R.
- Proof.
- unfold ConstSplitVector, ConstVector; simpl.
- now apply functional_extensionality.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_partvec_bounded {n} bound (pf:(bound<=n)%nat)
- (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (grad d:Vector float n):
-
- fully_closed_over df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall i : {n' : nat | n' < n},
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (scaleUnitVector n i (grad i) 0%R)) = Some (d i)) ->
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (mergeVectorZero bound grad 0%R)) = Some (vsum (vfirstn d bound pf)).
- Proof.
- intros closed lo fa.
- induction bound.
- - replace (mergeVectorZero 0 grad 0%R) with (scalarMult (DTVector n) 0%R (mergeVectorZero 0 grad 0%R)).
- + erewrite scalarMult_backprop_grad_scalar; try eassumption.
- * simpl.
- unfold df_eval_backprop_delta.
- simpler2.
- unfold lift.
- simpl_closed_backprop.
- f_equal.
- vm_compute; lra.
- * apply backprop_deriv_fully_closed_not_none; trivial.
- * apply backprop_deriv_fully_closed_not_none; trivial.
- + unfold scalarMult, mergeVectorZero.
- apply functional_extensionality; intros; simpl.
- lra.
- - assert (pf2:bound < n) by lia.
- assert (pf3:bound <= n) by lia.
- rewrite (mergeVectorSzero _ pf2 _ ).
- erewrite backprop_grad_sum; try eassumption.
- specialize (IHbound pf3).
- rewrite IHbound.
- rewrite fa.
- simpl.
- f_equal.
- destruct n; [lia | ].
- generalize (vector_Sn_split (vfirstn d (S bound) pf)); intros eqq1.
- apply vec_eq_eq in eqq1.
- simpl in *.
- rewrite eqq1.
- erewrite vfirstn_vdrop_last.
- erewrite vlast_vfirstn.
- rewrite vsum_cons.
- rewrite Rplus_comm.
- f_equal.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_partvec {n}
- (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (grad d:Vector float n):
-
- fully_closed_over df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall i : {n' : nat | n' < n},
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (scaleUnitVector n i (grad i) 0%R)) = Some (d i)) ->
- df_eval_backprop_delta σ df (s, DTfloat) grad_env grad =
- Some (vsum d).
- Proof.
- intros.
- replace (grad) with (mergeVectorZero n grad 0%R).
- - erewrite df_eval_backprop_delta_by_unit_partvec_bounded; try eassumption.
- now rewrite vfirstn_eq.
- - apply functional_extensionality; unfold mergeVectorZero; intros [x pff]; simpl.
- destruct (lt_dec x n); trivial; lia.
- Unshelve.
- lia.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_parts {n}
- (σ:df_env) (df:DefinedFunction UnitAnn (DTVector n)) (s: SubVar) grad_env (d:Vector float n):
-
- fully_closed_over df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall i : {n' : nat | n' < n},
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
-
- (UnitVector n i)) = Some (d i)) ->
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (ConstVector n 1%R)) = Some (vsum d).
- Proof.
- intros.
- replace (ConstVector n 1%R) with (mergeVectorZero n (ConstVector n 1%R) 0%R).
- - erewrite df_eval_backprop_delta_by_unit_partvec_bounded; try eassumption.
- now rewrite vfirstn_eq.
- - apply functional_extensionality; unfold mergeVectorZero, ConstVector; intros [x pff]; simpl.
- destruct (lt_dec x n); trivial; lia.
- Unshelve.
- lia.
- Qed.
-
- Lemma scalarMult_mult {T} a b grad : scalarMult T a (scalarMult T b grad) = scalarMult T (a*b)%R grad.
- Proof.
- destruct T; simpl.
- - lra.
- - apply vec_eq_eq; intros ?; lra.
- - do 2 (apply vec_eq_eq; intros ?); lra.
- Qed.
-
- Corollary scalarMult_backprop_grad0 {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (s: SubVar) (grad_env :df_env) (grad : definition_function_types_interp T) :
- let v := (s, DTfloat) in
- vartlookup grad_env v <> None ->
- df_eval_backprop_deriv σ df grad_env (scalarMult T 0%R grad) <> None ->
- df_eval_backprop_delta σ df v grad_env (scalarMult T 0%R grad) = Some 0%R.
- Proof.
- simpl; intros.
- (* This allows us to drop the (unneeded) assumption
- df_eval_backprop_deriv σ df grad_env1 grad <> None
- *)
- replace (scalarMult T 0%R grad) with (scalarMult T 0%R (scalarMult T 0%R grad)).
- - erewrite scalarMult_backprop_grad_scalar; try eassumption.
- + simpl.
- unfold df_eval_backprop_delta.
- simpler2; simpl.
- destruct (df_eval_backprop_deriv σ df grad_env (scalarMult T 0%R grad)); simpl; [| congruence].
- f_equal; lra.
- + now rewrite scalarMult_mult, Rmult_0_l.
- - now rewrite scalarMult_mult, Rmult_0_l.
- Qed.
-
- Lemma scaleUnitVec_vec_plus_distr m n i (x y:Vector float m) c :
- vec_eq c (dfti_gen_plus (T:=DTVector m) c c) ->
- (scaleUnitVector n i (dfti_gen_plus (T:=DTVector m) x y) c) =
- (dfti_gen_plus (T:=DTMatrix n m) (scaleUnitVector n i x c) (scaleUnitVector n i y c)).
- Proof.
- intros fa.
- do 2 (apply functional_extensionality; intro).
- unfold scaleUnitVector; simpl.
- match_destr.
- now specialize (fa x1); simpl in fa.
- Qed.
-
-
- Lemma df_eval_backprop_delta_by_unit_partmat_outer_bounded {n m} bound (pf:(bound<=n)%nat)
- (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env (grad d:Matrix float n m):
-
-
- fully_closed_over df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall (i : {n' : nat | n' < n}),
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (scaleUnitVector n i (grad i) (ConstVector m 0%R))) = Some (vsum (d i))) ->
-
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (mergeVectorZero bound grad (ConstVector m 0%R))) = Some (vsum (vfirstn (vmap vsum d) bound pf)).
- Proof.
- intros closed lo fa.
- induction bound.
- - rewrite vfirstn0, vsum_nil.
-
- replace (mergeVectorZero 0 grad (ConstVector m 0%R)) with (scalarMult (DTMatrix n m) 0%R (mergeVectorZero 0 grad (ConstVector m 0%R))).
- + apply scalarMult_backprop_grad0; simpl in *; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- + unfold scalarMult, ConstVector, mergeVectorZero.
- do 2 (apply functional_extensionality; intros); simpl.
- lra.
- - assert (pf2:bound < n) by lia.
- assert (pf3:bound <= n) by lia.
- simpl.
- generalize (mergeVectorSzero_mat (m:=m) _ pf2 grad ); intros HH; unfold Vector in HH.
- simpl float in *.
- rewrite HH; clear HH.
- erewrite backprop_grad_sum; try eassumption.
- specialize (IHbound pf3).
- rewrite IHbound.
- rewrite fa.
- simpl.
- f_equal.
- generalize (vector_Sn_split (vfirstn (vmap vsum d) (S bound) pf)); intros eqq1.
- apply vec_eq_eq in eqq1.
- simpl in *.
- rewrite eqq1.
- erewrite vfirstn_vdrop_last.
- erewrite vlast_vfirstn.
- rewrite vsum_cons.
- rewrite Rplus_comm.
- f_equal.
- rewrite vmap_nth.
- eauto.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_partmat_inner_bounded {n m} bound (pf:(bound<=m)%nat)
- (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env (grad d:Matrix float n m) i:
-
-
- fully_closed_over df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall j : {n' : nat | n' < m},
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- ((scaleUnitVector n i
- (scaleUnitVector m j (grad i j) 0%R) (ConstVector m 0%R)))) = Some (d i j)) ->
-
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (scaleUnitVector n i (mergeVectorZero bound (grad i) 0%R) (ConstVector m 0%R))) = Some (vsum (vfirstn (d i) bound pf)).
- Proof.
- intros closed lo fa.
- induction bound.
- - rewrite vfirstn0, vsum_nil.
- replace (scaleUnitVector n i (mergeVectorZero 0 (grad i) 0%R) (ConstVector m 0%R)) with (scalarMult (DTMatrix n m) 0%R (mergeVectorZero 0 grad (ConstVector m 0%R))).
- + apply scalarMult_backprop_grad0; simpl in *; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- + unfold scalarMult, ConstVector, mergeVectorZero, scaleUnitVector.
- do 2 (apply functional_extensionality; intros); simpl.
- match_destr; lra.
- - assert (pf2:bound < m) by lia.
- assert (pf3:bound <= m) by lia.
- simpl.
- generalize (mergeVectorSzero _ pf2 (grad i) ); intros HH; unfold Vector in HH.
- simpl float in *.
- rewrite HH; clear HH.
- rewrite (scaleUnitVec_vec_plus_distr m n i) by (intros ?; unfold ConstVector; simpl; lra).
- erewrite backprop_grad_sum; try eassumption.
- specialize (IHbound pf3).
- simpl in *.
- rewrite IHbound.
- simpl.
- rewrite fa.
- simpl.
- f_equal.
- generalize (vector_Sn_split (vfirstn (d i) (S bound) pf)); intros eqq1.
- apply vec_eq_eq in eqq1.
- simpl in *.
- rewrite eqq1.
- erewrite vfirstn_vdrop_last.
- erewrite vlast_vfirstn.
- rewrite vsum_cons.
- rewrite Rplus_comm.
- f_equal.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_partmat {n m}
- (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env
- (grad d:Matrix float n m):
- fully_closed_over
- df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve)
- σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
- (forall (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) ,
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (scalarMult (DTMatrix n m) (grad i j)
- (UnitMatrix n m i j))) =
- Some (d i j)) ->
- df_eval_backprop_delta σ df (s, DTfloat) grad_env grad =
- Some (msum d).
- Proof.
- intros closed lo fa.
- replace (grad) with (mergeVectorZero n grad (ConstVector m 0%R)).
- - erewrite df_eval_backprop_delta_by_unit_partmat_outer_bounded; try eassumption.
- + now rewrite vfirstn_eq.
- + intros i.
- specialize (fa i).
- simpl in fa.
- replace (grad i) with (mergeVectorZero m (grad i) 0%R).
- 2: {
- unfold mergeVectorZero; simpl; apply functional_extensionality; intros; destruct x.
- simpl.
- match_destr; lia.
- }
-
- erewrite (df_eval_backprop_delta_by_unit_partmat_inner_bounded m (le_refl m) σ df s grad_env)
- ; try eassumption.
- * now rewrite vfirstn_eq.
- * intros j.
- specialize (fa j); simpl in *.
- rewrite <- fa.
- f_equal.
- do 2 (apply functional_extensionality; intros).
- unfold scaleUnitVector, ConstVector, UnitMatrix; simpl.
- repeat match_destr; lra.
- - apply functional_extensionality; unfold mergeVectorZero; intros [x pff]; simpl.
- destruct (lt_dec x n); trivial; lia.
- Unshelve.
- lia.
- Qed.
-
- Lemma df_eval_backprop_delta_by_unit_parts_mat {n m}
- (σ:df_env) (df:DefinedFunction UnitAnn (DTMatrix n m)) (s: SubVar) grad_env
- (d:Matrix float n m):
- fully_closed_over
- df
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve)
- σ) ->
- vartlookup grad_env (s, DTfloat) <> None ->
-
- (forall (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) ,
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
-
- (UnitMatrix n m i j)) = Some (d i j)) ->
- (df_eval_backprop_delta σ df (s, DTfloat) grad_env
- (ConstMatrix n m 1%R)) = Some (msum d).
- Proof.
- intros.
- apply df_eval_backprop_delta_by_unit_partmat; trivial.
- intros.
- rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env); trivial.
- unfold lift.
- rewrite H1.
- f_equal.
- unfold scalarMult, ConstMatrix; simpl; lra.
- apply backprop_deriv_fully_closed_not_none; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- Qed.
-
- Corollary scalarMult_backprop_list_env_iter_grad0 {T} (σ:df_env) (s: SubVar) (grad_env :df_env) (grad : definition_function_types_interp T) old n x l :
- let v := (s, DTfloat) in
- (forall j : {n' : nat | n' < n},
- fully_closed_over (x j)
- (map (fun ve : {v : var_type & definition_function_types_interp (snd v)} => projT1 ve) σ)) ->
- vartlookup grad_env v = Some old ->
- lift (fun e => subvar v e old)
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv (Ann:=UnitAnn) σ (x i) env (scalarMult T 0%R grad)) (Some grad_env) l) = Some 0%R.
- Proof.
- simpl; intros.
- revert grad_env old H0.
- induction l; simpl; intros.
- - unfold subvar; simpl.
- rewrite H0; f_equal.
- lra.
- - unfold lift in *.
- case_eq (df_eval_backprop_deriv σ (x a) grad_env (scalarMult T 0%R grad))
- ; intros.
- + apply IHl.
- * generalize (scalarMult_backprop_grad0 σ (x a) s grad_env grad); intros HH.
- simpl in HH.
- cut_to HH; try congruence.
- unfold df_eval_backprop_delta in HH.
- rewrite H0 in HH.
- rewrite H1 in HH.
- simpl in HH.
- invcs HH.
- unfold subvar in H3; simpl in H3.
- simpler2.
- rewrite eqq in eqq0; invcs eqq0; subst.
- f_equal.
- lra.
- + eelim backprop_deriv_fully_closed_not_none; eauto.
- Qed.
-
- Lemma list_env_iter_gen_delta0 {n} (s: SubVar) (init_env : df_env)
- (f: {n' : nat | n' < n} -> df_env -> option df_env) (l : list {n' : nat | n' < n}):
- let v := (s,DTfloat) in
- vartlookup init_env v <> None ->
- (forall (i : {n' : nat | n' < n}) (env : df_env),
- (f i env) <> None /\
- (vartlookup env v <> None ->
- match f i env with
- | Some xenv => vartlookup xenv v <> None
- | _ => True
- end ) /\
- (In i l -> vartlookup env v <> None ->
- lift2 (fun e val => subvar v e val) (f i env)
- (vartlookup env v) = Some 0%R)) ->
- lift2 (fun e old => subvar v e old)
- (list_env_iter f (Some init_env) l)
- (vartlookup init_env v) = Some 0%R.
- Proof.
- simpl; intros.
- revert init_env H H0.
- induction l.
- - intros.
- simpl; f_equal.
- unfold subvar; simpl.
- match_option; [|tauto].
- f_equal.
- lra.
- - intros.
- assert (H0c := H0).
- specialize (H0 a init_env).
- destruct H0; destruct H1.
- simpl.
- case_eq (f a init_env); [intros|tauto].
- specialize (IHl d).
- replace (vartlookup init_env (s, DTfloat)) with (vartlookup d (s, DTfloat)).
- + apply IHl.
- * specialize (H1 H).
- now rewrite H3 in H1.
- * intros.
- specialize (H0c i env).
- destruct H0c; destruct H5.
- split; trivial.
- split; trivial.
- intros.
- cut_to H6; trivial.
- simpl; tauto.
- + case_eq (vartlookup init_env (s, DTfloat)); [intros|tauto].
- rewrite H3, H4 in H2; simpl in H2.
- cut_to H2; try tauto.
- unfold lift2 in H2; simpl in H2.
- inversion H2.
- unfold subvar in H6; simpl in H6.
- rewrite H3 in H1.
- specialize (H1 H).
- case_eq ( vartlookup d (s, DTfloat)); [intros|tauto].
- rewrite H5 in H6.
- f_equal; lra.
- congruence.
- Qed.
-
- Lemma list_env_iter_gen_delta {n} (s: SubVar) (init_env : df_env) (old : float)
- (f: {n' : nat | n' < n} -> df_env -> option df_env) (i0 : {n' : nat | n' < n}):
- let v := (s,DTfloat) in
- vartlookup init_env v = Some old ->
- (forall (i : {n' : nat | n' < n}) (env : df_env),
- (f i env) <> None /\
- (vartlookup env v <> None ->
- match f i env with
- | Some xenv => vartlookup xenv v <> None
- | _ => True
- end ) /\
- (forall (env2 : df_env),
- match vartlookup env v, vartlookup env2 v, f i env, f i env2 with
- | Some val1, Some val2, Some xenv, Some xenv2 =>
- subvar v xenv val1 = subvar v xenv2 val2
- | _, _, _, _ => True
- end) /\
- ( i <> i0 -> vartlookup env v <> None ->
- lift2 (fun e val => subvar v e val) (f i env)
- (vartlookup env v) = Some 0%R)) ->
- lift (fun e => subvar v e old) (f i0 init_env) =
- lift (fun e => subvar v e old)
- (list_env_iter f
- (Some init_env) (bounded_seq0 n)).
- Proof.
- simpl; intros.
- unfold bounded_seq0.
- destruct (bounded_seq_break_at 0 n i0) as [b [c [eqq1 [fa1 fa2]]]]; [lia |].
- rewrite eqq1.
- rewrite list_env_iter_app; simpl.
- match_option.
- -
- assert (eqq': subvar (s, DTfloat) d old = 0%R).
- + generalize (list_env_iter_gen_delta0 s init_env f b); intros.
- simpl in H1.
- cut_to H1; try congruence.
- * unfold lift2 in H1.
- rewrite H, eqq in H1.
- now inversion H1.
- * intros.
- specialize (H0 i env).
- destruct H0; destruct H2; destruct H3.
- split; trivial.
- split; trivial.
- intros.
- cut_to H4; trivial.
- rewrite Forall_forall in fa1.
- specialize (fa1 i H5).
- intro eq1; rewrite eq1 in fa1; lia.
- + generalize (vartlookup_list_env_iter s f b); intros vart.
- * specialize (vart init_env d eqq).
- assert (vartinit: vartlookup init_env (s, DTfloat) <> None) by congruence.
- specialize (vart vartinit).
- assert (f i0 d <> None) by apply H0.
- case_eq (f i0 d); [intros | tauto].
- generalize (list_env_iter_gen_delta0 s d0 f c); simpl; intros.
- cut_to H3; try congruence.
- -- unfold lift at 2.
- unfold lift2 in H3.
- match_option.
- ++ rewrite eqq0 in H3.
- match_option_in H3.
- rewrite (split_subvar d0 d1 old d2);trivial.
- inversion H3.
- rewrite H5.
- unfold lift.
- match_option.
- ** f_equal.
- case_eq (vartlookup d (s, DTfloat)); intros.
- --- rewrite (split_subvar d d0 old d4); trivial.
- rewrite eqq'.
- specialize (H0 i0 init_env).
- destruct H0; destruct H6; destruct H7.
- specialize (H7 d).
- rewrite H, H4, eqq3, H2 in H7.
- rewrite H7; lra.
- --- cut_to vart; [tauto|].
- intros.
- specialize (H0 i env).
- destruct H0; destruct H8.
- specialize (H8 H7).
- now rewrite H6 in H8.
- ** specialize (H0 i0 init_env).
- destruct H0; congruence.
- ++ generalize (list_env_iter_total_fun f d0 c); intros.
- cut_to H4; try congruence.
- apply H0.
- -- cut_to vart.
- ++ specialize (H0 i0 d).
- destruct H0; destruct H4.
- rewrite H2 in H4.
- apply H4; trivial.
- ++ intros.
- specialize (H0 i env).
- destruct H0;destruct H6.
- specialize (H6 H5).
- now rewrite H4 in H6.
- -- intros.
- specialize (H0 i env).
- split.
- apply H0.
- split.
- apply H0.
- intros.
- assert (i <> i0).
- ++ rewrite Forall_forall in fa2.
- specialize (fa2 i H4).
- intro eq1; rewrite eq1 in fa2; lia.
- ++ now apply H0.
- - generalize (list_env_iter_total_fun f init_env b); intros.
- cut_to H1; trivial.
- tauto.
- intros.
- apply H0.
- Qed.
-
- Lemma list_env_iter_vec_delta {n} (σ:df_env)
- (x:Vector (DefinedFunction UnitAnn DTfloat) n) (s: SubVar) grad_env
- (i0 : {n' : nat | n' < n}) (old : float) :
- let v := (s, DTfloat) in
- vartlookup grad_env v = Some old ->
- (forall (j: {n' : nat | n' < n}) ,
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over (x j) vl) ->
- lift (fun e => subvar v e old)
- (df_eval_backprop_deriv σ (x i0) grad_env 1%R) =
- lift (fun e => subvar v e old)
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i))
- (Some grad_env) (bounded_seq0 n)).
- Proof.
- simpl; intros.
- generalize (list_env_iter_gen_delta
- s grad_env old
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i))
- i0).
- simpl; intros.
- specialize (H1 H).
- rewrite <- H1.
- - unfold UnitVector; simpl.
- now destruct (equiv_dec (` i0) (` i0)); [|congruence].
- - clear H1.
- intros.
- split; [|split].
- + apply backprop_deriv_fully_closed_not_none; auto.
- + intros.
- match_option.
- now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial.
- + split.
- * intros.
- do 4 match_option.
- generalize (backprop_indep_env σ (x i) s env env2
- (UnitVector n i0 i)); simpl; intros HH.
- cut_to HH; trivial; try congruence.
- unfold df_eval_backprop_delta in HH.
- rewrite eqq,eqq0,eqq1,eqq2 in HH.
- unfold lift in HH.
- now inversion HH.
- * intros.
- unfold UnitMatrix; simpl.
- destruct (equiv_dec (` i) (` i0)).
- elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel.
- unfold lift2.
- generalize (scalarMult_backprop_grad0 σ (x i) s env 0%R); simpl; intros.
- unfold df_eval_backprop_delta in H3.
- specialize (H3 H2).
- replace (0 * 0)%R with 0%R in H3 by lra.
- match_option.
- match_option; [|tauto].
- -- rewrite eqq0 in H3.
- replace (UnitVector n i0 i) with Fzero in eqq; simpl in eqq.
- ++ rewrite eqq in H3; unfold lift in H3.
- apply H3; congruence.
- ++ unfold UnitVector; simpl.
- destruct (equiv_dec (` i) (` i0)); [|trivial].
- elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel.
- -- assert (df_eval_backprop_deriv σ (x i) env (UnitVector n i0 i) <> None).
- now apply backprop_deriv_fully_closed_not_none.
- tauto.
- Qed.
-
-
- Lemma list_env_iter_backprop_indep_env_vec {m} (σ:df_env)
- (vecdf: Vector (DefinedFunction UnitAnn DTfloat) m)
- (s:SubVar) (env env2:df_env) (grad: Vector float m)
- (old1 old2: float) :
- let v := (s, DTfloat) in
- vartlookup env v = Some old1 ->
- vartlookup env2 v = Some old2 ->
- (let vl := map (fun ve => projT1 ve) σ in
- forall j : {m' : nat | m' < m},
- fully_closed_over (vecdf j) vl) ->
- lift (fun e => subvar (s, DTfloat) e old1)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env0 (grad j))
- (Some env) (bounded_seq0 m)) =
- lift (fun e => subvar (s, DTfloat) e old2)
- (list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env0 (grad j))
- (Some env2) (bounded_seq0 m)).
- Proof.
- intros.
- subst v.
- revert old1 old2 env env2 H H0.
- induction (bounded_seq0 m).
- - intros.
- simpl.
- unfold subvar; simpl.
- rewrite H, H0.
- f_equal; lra.
- - intros.
- simpl.
- case_eq (df_eval_backprop_deriv σ (vecdf a) env (grad a)); intros.
- case_eq (df_eval_backprop_deriv σ (vecdf a) env2 (grad a)); intros.
- case_eq (vartlookup d (s, DTfloat)); intros.
- case_eq (vartlookup d0 (s, DTfloat)); intros.
- + specialize (IHl d1 d2 d d0 H4 H5).
- unfold lift.
- do 2 match_option.
- * rewrite eqq, eqq0 in IHl.
- unfold lift in IHl.
- inversion IHl.
- f_equal.
- rewrite (split_subvar d d3 old1 d1); trivial.
- rewrite (split_subvar d0 d4 old2 d2); trivial.
- rewrite H7.
- generalize (backprop_indep_env
- σ (vecdf a) s
- env env2 (grad a)); simpl; intros.
- cut_to H6; trivial; try congruence.
- unfold df_eval_backprop_delta in H6.
- rewrite H, H0, H2, H3 in H6.
- unfold lift in H6.
- inversion H6.
- rewrite H9; lra.
- * generalize (list_env_iter_total_fun
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env0 (grad j))
- d0 l); intros.
- cut_to H6; [tauto|].
- intros.
- apply backprop_deriv_fully_closed_not_none; auto.
- * generalize (list_env_iter_total_fun
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (vecdf j) env0 (grad j))
- d l); intros.
- cut_to H6; [tauto|].
- intros.
- apply backprop_deriv_fully_closed_not_none; auto.
- + apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H3
- ;trivial; congruence.
- + apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H2
- ;trivial; congruence.
- + assert (df_eval_backprop_deriv σ (vecdf a) env2 (grad a) <> None) by
- (apply backprop_deriv_fully_closed_not_none; auto); tauto.
- + assert (df_eval_backprop_deriv σ (vecdf a) env (grad a) <> None) by
- (apply backprop_deriv_fully_closed_not_none; auto); tauto.
- Qed.
-
- Lemma list_env_iter_mat_delta {n m} (σ:df_env)
- (x:Matrix (DefinedFunction UnitAnn DTfloat) n m) (s: SubVar) grad_env
- (i0 : {n' : nat | n' < n})
- (j0 : {m' : nat | m' < m}) (old : float) :
- let v := (s, DTfloat) in
- vartlookup grad_env v = Some old ->
- (forall (i: {n' : nat | n' < n}) (j: {m' : nat | m' < m}) ,
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over (x i j) vl) ->
- lift (fun e => subvar v e old)
- (df_eval_backprop_deriv σ (x i0 j0) grad_env 1%R) =
- lift (fun e => subvar v e old)
- (list_env_iter
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv
- σ (x i j) env0
- (UnitMatrix n m i0 j0 i j)) (Some env)
- (bounded_seq0 m)) (Some grad_env) (bounded_seq0 n)).
- Proof.
- simpl; intros.
- generalize (list_env_iter_gen_delta
- s grad_env old
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv
- σ (x i j) env0
- (UnitMatrix n m i0 j0 i j)) (Some env)
- (bounded_seq0 m)) i0); simpl.
- intros.
- specialize (H1 H).
- rewrite <- H1.
- - clear H1.
- generalize (list_env_iter_gen_delta
- s grad_env old
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i0 j) env0 (UnitMatrix n m i0 j0 i0 j))
- j0); simpl.
- intros.
- specialize (H1 H).
- rewrite <- H1.
- + unfold UnitMatrix; simpl.
- destruct (equiv_dec (` i0) (` i0)); [|congruence].
- now destruct (equiv_dec (` j0) (` j0)); [|congruence].
- + clear H1.
- intros.
- split; [|split].
- * apply backprop_deriv_fully_closed_not_none; auto.
- * intros.
- match_option.
- now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial.
- * split.
- -- intros.
- do 4 match_option.
- generalize (backprop_indep_env σ (x i0 i) s env env2
- (UnitMatrix n m i0 j0 i0 i)); simpl; intros HH.
- cut_to HH; trivial; try congruence.
- unfold df_eval_backprop_delta in HH.
- rewrite eqq,eqq0,eqq1,eqq2 in HH.
- unfold lift in HH.
- now inversion HH.
- -- intros.
- unfold UnitMatrix; simpl.
- destruct (equiv_dec (` i0) (` i0)); [|congruence].
- destruct (equiv_dec (` i) (` j0)).
- elim H1; destruct i; destruct j0; simpl in *; red in e0; subst; apply index_pf_irrel.
- unfold lift2.
- generalize (scalarMult_backprop_grad0 σ (x i0 i) s env 0%R); simpl; intros.
- unfold lift2.
- replace (0 * 0)%R with 0%R in H3 by lra.
- unfold df_eval_backprop_delta in H3.
- match_option.
- match_option.
- rewrite eqq0, eqq in H3; unfold lift in H3.
- cut_to H3; congruence.
- rewrite eqq0 in H3.
- apply H3; trivial.
- tauto.
- congruence.
- assert (df_eval_backprop_deriv σ (x i0 i) env 0%R <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- - split.
- generalize (list_env_iter_total_fun
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j))
- env (bounded_seq0 m)); intros.
- apply H2.
- intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- split.
- + intros.
- match_option.
- apply (vartlookup_list_env_iter
- s (fun (j : {m' : nat | m' < m})
- (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j))
- (bounded_seq0 m) env d); trivial.
- intros.
- apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in H3; trivial.
- + split.
- * intros.
- do 4 match_option.
- generalize (list_env_iter_backprop_indep_env_vec
- σ (x i) s env env2
- (UnitMatrix n m i0 j0 i)
- d d0); simpl; intros.
- specialize (H2 eqq eqq0).
- specialize (H2 (H0 i)).
- rewrite eqq1, eqq2 in H2.
- unfold lift in H2.
- now inversion H2.
- * intros.
- replace (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0 (UnitMatrix n m i0 j0 i j))
- with
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv σ (x i j) env0
- (scalarMult DTfloat 0%R 0%R)).
- case_eq (vartlookup env (s, DTfloat)); [intros | tauto].
- apply scalarMult_backprop_list_env_iter_grad0; trivial.
- apply functional_extensionality; intros.
- apply functional_extensionality; intros.
- f_equal.
- unfold scalarMult, UnitMatrix; simpl.
- destruct (equiv_dec (` i) (` i0)).
- red in e.
- elim H2; destruct i; destruct i0; simpl in *; subst.
- apply index_pf_irrel.
- lra.
- Qed.
-
- Lemma list_env_iter_matvec_delta {m n} (σ:df_env)
- (df2:DefinedFunction UnitAnn (DTVector m)) (s: SubVar) grad_env
- (i0 : {n' : nat | n' < n})
- (j0 : {m' : nat | m' < m}) (old : float) :
- vartlookup grad_env (s, DTfloat) = Some old ->
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df2 vl ->
- (lift (fun e => subvar (s, DTfloat) e old)
- (df_eval_backprop_deriv
- σ df2 grad_env
- (UnitVector m j0)) =
- (lift (fun e => subvar (s, DTfloat) e old)
- (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ df2 env
- ((UnitMatrix n m i0 j0) i))
- (Some grad_env)
- (bounded_seq0 n)))).
- Proof.
- simpl; intros.
- generalize (list_env_iter_gen_delta
- s grad_env old
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ df2 env
- ((UnitMatrix n m i0 j0) i))
- i0).
- simpl; intros.
- specialize (H1 H).
- rewrite <- H1.
- - unfold UnitMatrix, UnitVector; simpl.
- now destruct (equiv_dec (` i0) (` i0)); [|congruence].
- - clear H1.
- intros.
- split; [|split].
- + apply backprop_deriv_fully_closed_not_none; auto.
- + intros.
- match_option.
- now apply df_eval_backprop_deriv_preserves_lookup_not_none with (xv := (s,DTfloat)) in eqq;trivial.
- + split.
- * intros.
- do 4 match_option.
- generalize (backprop_indep_env σ df2 s env env2
- (UnitMatrix n m i0 j0 i)); simpl; intros HH.
- cut_to HH; trivial; try congruence.
- unfold df_eval_backprop_delta in HH.
- rewrite eqq,eqq0,eqq1,eqq2 in HH.
- unfold lift in HH.
- now inversion HH.
- * intros.
- unfold UnitMatrix; simpl.
- destruct (equiv_dec (` i) (` i0)).
- elim H1; destruct i; destruct i0; simpl in *; red in e; subst; apply index_pf_irrel.
- unfold lift2.
- generalize (scalarMult_backprop_grad0 σ df2 s env (ConstVector m 0%R)); simpl; intros.
- unfold df_eval_backprop_delta in H3.
- specialize (H3 H2).
- match_option.
- match_option; [|tauto].
- -- rewrite eqq0 in H3.
- replace (fun i : {n' : nat | n' < m} => (0 * ConstVector m 0 i)%R) with
- (fun i : {n' : nat | n' < m} => 0%R) in H3.
- ++ rewrite eqq in H3; unfold lift in H3.
- apply H3; congruence.
- ++ apply functional_extensionality; intros.
- unfold ConstVector; simpl.
- lra.
- -- assert (df_eval_backprop_deriv σ df2 env (fun _ => 0%R) <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- tauto.
- Qed.
-
- Lemma vmap_eta {A B} {n} (f:A->B) (d:Vector A n) : vmap f d = vmap (fun x => f x) d.
- Proof.
- now apply vmap_ext.
- Qed.
-
- Lemma vsum_eta {n} (d:Vector float n) : vsum d = vsum (fun i => d i).
- Proof.
- now apply vsum_ext.
- Qed.
-
- Lemma msum_unitvector m n x d1 d0 :
- msum
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (UnitVector m x i * d1 i j * d0 j)%R) =
- vsum
- (fun j : {n' : nat | n' < n} =>
- (d1 x j * d0 j)%R).
- Proof.
- unfold msum.
-
- transitivity (
- vsum
- (fun i : {n' : nat | n' < m} =>
- vsum
- (fun (j : {m' : nat | m' < n}) => (UnitVector m x i * d1 i j * d0 j)%R))).
- { apply vsum_ext; intros ?.
- now rewrite vmap_nth.
- }
- transitivity (
- vsum
- (fun i : {n' : nat | n' < m} =>
- ((UnitVector m x i * vsum
- (fun (j : {m' : nat | m' < n}%nat) => d1 i j * d0 j))%R))).
- {
- apply vsum_ext; intros ?.
- rewrite vsum_mult.
- apply vsum_ext; intros ?.
- lra.
- }
- transitivity (
- vsum
- (fun i : {n' : nat | n' < m} =>
- (vsum (fun j : {m' : nat | (m' < n)%nat} => d1 i j * d0 j) * UnitVector m x i)%R)).
- {
- apply vsum_ext; intros ?; lra.
- }
- now rewrite vsum_unitvector.
- Qed.
-
- Lemma vsum_as_sum {n} (v:Vector float n) : vsum v = fold_right Rplus R0 (vector_to_list v).
- Proof.
- rewrite vsum_alt_eq.
- unfold vector_to_list, vector_fold_right.
- induction n.
- - rewrite vector_fold_right_dep_0; trivial.
- - repeat rewrite vector_fold_right_dep_Sn.
- now rewrite IHn.
- Qed.
-
- Lemma msum_as_sum {m n} (mat:Matrix float m n) : msum mat = fold_right Rplus R0 (matrix_to_list mat).
- Proof.
- unfold msum, matrix_to_list, matrix_to_list_list.
- transitivity
- (vsum (fun i => fold_right Rplus R0 (vector_to_list (mat i)))).
- { apply vsum_ext; intros [??].
- now rewrite vmap_nth, vsum_as_sum.
- }
- rewrite vsum_as_sum.
- rewrite fold_right_plus_concat.
- rewrite map_vector_to_list_vmap.
- f_equal.
- apply vector_to_list_ext.
- intros [??].
- now rewrite vmap_nth.
- Qed.
-
- Lemma transpose_perm_bounded {m n : nat} (mat : Matrix float m n) bound_m pf_m bound_n pf_n:
- Permutation
- (concat
- (vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) []
- (fun i : {n' : nat | n' < m} =>
- vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] (mat i) bound_n pf_n) bound_m pf_m))
- (concat
- (vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) []
- (fun i : {n' : nat | n' < n} =>
- vector_fold_right_bounded_dep (fun _ : nat => Datatypes.cons) [] (transpose mat i) bound_m pf_m) bound_n pf_n)).
- Proof.
- Hint Constructors Permutation.
- revert bound_n pf_n.
- induction bound_m; intros; simpl.
- - induction bound_n; simpl; trivial.
- - rewrite IHbound_m.
- clear IHbound_m.
- induction bound_n; simpl; trivial.
- rewrite <- IHbound_n.
- clear IHbound_n.
- apply Permutation_cons; trivial.
- unfold transpose; simpl.
- repeat rewrite <- app_ass.
- apply Permutation_app; trivial.
- apply Permutation_app_comm.
- Qed.
-
- Lemma transpose_perm {m n} (mat : Matrix float m n) :
- Permutation (matrix_to_list mat) (matrix_to_list (transpose mat)).
- Proof.
- apply transpose_perm_bounded.
- Qed.
-
- Lemma msum_transpose {m n} (mat : Matrix float m n) :
- msum mat = msum (transpose mat).
- Proof.
- repeat rewrite msum_as_sum.
- apply fold_right_perm; intros; try lra.
- apply transpose_perm.
- Qed.
-
- Ltac match_nested_case :=
- match goal with
- | [|- context[match match ?x with _ => _ end with _ => _ end]] =>
- let eqq := fresh "eqq" in
- case_eq x
- ; [intros ? eqq | intros eqq]
- end.
-
- Ltac match_nested_case_in H :=
- match H with
- | context[match match ?x with _ => _ end with _ => _ end] =>
- let eqq := fresh "eqq" in
- case_eq x
- ; [intros ? eqq | intros eqq]
- ; rewrite eqq in H
- end.
-
- Theorem df_eval_deriv_genvar_same (σ:df_env) (df:DefinedFunction UnitAnn DTfloat) (v:SubVar) :
- let vl := map (fun ve => projT1 ve) σ in
- is_scalar_function df ->
- fully_closed_over df vl ->
- let forward := df_eval_deriv_gen_top σ df (v, DTfloat) in
- lift transpose_lifted_type forward = df_eval σ (df_deriv df (v,DTfloat)).
- Proof.
- simpl.
- intros is_scalar.
- generalize is_scalar.
- pattern df.
- revert df is_scalar.
- DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl; trivial; intros
- ; try
- ( cut_to H; trivial
- ;rewrite <- H
- ;assert (df_eval σ e <> None) by (apply eval_fully_closed_not_none; trivial)
- ;unfold lift, df_eval_deriv_gen_top; simpl
- ;match_nested_case; [|tauto]
- ;now match_nested_case).
-
- - Case "Var"%string.
- match_nested_case.
- + red in e.
- inversion e; subst.
- now refl_simpler; simpl.
- + now intros.
- - Case "Plus"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_option_in H.
- + match_option_in H0.
- * rewrite <- H.
- now rewrite <- H0.
- * assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- - Case "Minus"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_option_in H.
- + match_option_in H0.
- * rewrite <- H.
- now rewrite <- H0.
- * assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- - Case "Times"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_nested_case; trivial.
- match_option_in H.
- + match_nested_case.
- * match_option_in H0.
- -- rewrite <- H.
- now rewrite <- H0.
- -- assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- * assert (df_eval σ r <> None)
- ; [apply eval_fully_closed_not_none; trivial | tauto].
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- - Case "Divide"%string.
- destruct H1.
- destruct is_scalar.
- cut_to H; trivial.
- cut_to H0; trivial.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_nested_case; trivial.
- match_option_in H.
- + match_nested_case.
- * match_option_in H0.
- -- rewrite <- H.
- now rewrite <- H0.
- -- assert ( df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- * assert (df_eval σ r <> None)
- ; [apply eval_fully_closed_not_none; trivial | tauto].
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- + assert (df_eval σ l <> None)
- ; [apply eval_fully_closed_not_none; trivial | tauto].
- - Case "Sign"%string.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_nested_case; [trivial|].
- assert (df_eval_deriv_genvar σ e [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- - Case "PSign"%string.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- match_nested_case; [trivial|].
- assert (df_eval_deriv_genvar σ e [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- - Case "Max"%string.
- destruct is_scalar; destruct H1.
- cut_to H; trivial.
- cut_to H0; trivial.
- unfold lift, df_eval_deriv_gen_top in *; simpl in *.
- assert (df_eval σ l <> None) by (apply eval_fully_closed_not_none; trivial).
- assert (df_eval σ r <> None) by (apply eval_fully_closed_not_none; trivial).
- match_nested_case; [|tauto].
- match_nested_case; [|tauto].
- match_option_in H.
- match_option_in H0.
- + rewrite <- H.
- rewrite <- H0.
- unfold pos_sign; simpl.
- case_eq (Rle_dec d d0); intros; f_equal.
- destruct (Rge_dec (d0 - d) 0); lra.
- destruct (Rge_dec (d0 - d) 0); lra.
- + assert (df_eval_deriv_genvar σ r [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (v, DTfloat) 1%R] <> None); [|tauto].
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- Qed.
-
- Lemma yay {T} (σ:df_env) (df:DefinedFunction UnitAnn T) (hs:has_scalar_functions df) (s: SubVar) grad_env :
- let v := (s, DTfloat) in
- vartlookup grad_env v <> None ->
- let vl := map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- let forward := df_eval_deriv_gen_top σ df v in
- let backward := df_eval_backward_gen_top σ df v grad_env in
- lift transpose_lifted_type forward = backward.
- Proof.
- simpl.
- intros vin closed.
- revert grad_env vin closed.
- unfold df_eval_deriv_gen_top, df_eval_backward_gen_top.
- pattern T, df.
- revert T df hs.
- DefinedFunction_cases (apply DefinedFunction_ind_unit_has_scalar_functions) Case
- ; simpl; intros.
- - Case "Number"%string.
- unfold subvar; simpl.
- match_destr; [ | tauto].
- f_equal; lra.
- - Case "Constant"%string.
- unfold subvar; simpl.
- match_destr; simpl.
- + match_destr; [ | tauto].
- f_equal; lra.
- + match_destr; [ | tauto].
- erewrite vectoro_to_ovector_forall_some_b_strong
- ; simpl; trivial; intros.
- unfold ConstVector.
- f_equal; lra.
- + match_destr; [ | tauto].
- unfold matrixo_to_omatrix.
- repeat (erewrite vectoro_to_ovector_forall_some_b_strong
- ; simpl; trivial; intros).
- unfold ConstMatrix.
- f_equal; lra.
- - Case "DVector"%string.
- unfold lift, two_vector_env_iter_alt.
- case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto].
- match_option.
- + specialize (apply vectoro_to_ovector_forall_some_f eqq); intros.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- unfold snd in H.
- specialize (H i grad_env); simpl in H.
- rewrite vforall_forall in closed.
- assert (closedb := closed).
- specialize (closed i).
- specialize (H vin closed); simpl in H.
- rewrite H0 in H.
- specialize (H1 i); simpl in H1.
- rewrite H1 in H; simpl in H.
- unfold lift in H.
- match_option_in H.
- specialize (apply vectoro_to_ovector_forall_some_f eqq); intros.
- specialize (H2 i); simpl in H2.
- destruct i; simpl.
- unfold UnitVector; simpl.
- generalize (list_env_iter_total_fun
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- df_eval_backprop_deriv σ (x i) env
- (if equiv_dec (` i) x0 then 1%R else 0%R))
- grad_env (bounded_seq0 n)); intros.
- cut_to H3.
- match_option; [|tauto].
- * rewrite H.
- generalize (list_env_iter_vec_delta σ x s grad_env
- (exist (fun n' : nat => n' < n) x0 l) d).
- intros.
- simpl in H4.
- specialize (H4 H0 closedb).
- rewrite eqq0 in H4.
- unfold UnitVector in H4; simpl in H4.
- rewrite eqq1 in H4.
- unfold lift in H4.
- symmetry; trivial.
- * intros.
- apply backprop_deriv_fully_closed_not_none.
- apply closedb.
- + specialize (vectoro_to_ovector_exists_None eqq); intros.
- destruct H1.
- generalize (eval_deriv_genvar_fully_closed_not_none σ (x x0) [mk_env_entry (s, DTfloat) 1%R]); intros.
- rewrite vforall_forall in closed.
- now specialize (H1 (closed x0)).
- - Case "DMatrix"%string.
- unfold lift, matrixo_to_omatrix.
- case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto].
- match_option.
- + specialize (apply vectoro_to_ovector_forall_some_f eqq); intros.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H1 i); simpl in H1.
- specialize (apply vectoro_to_ovector_forall_some_f H1); intros.
- specialize (H2 i0); simpl in H2.
- unfold two_matrix_env_iter_alt.
- specialize (H i i0 grad_env); simpl in H.
- rewrite vforall_forall in closed.
- assert (closedb := closed).
- specialize (closed i).
- rewrite vforall_forall in closed.
- specialize (closed i0).
- specialize (H vin closed); simpl in H.
- rewrite H2 in H; simpl in H.
- rewrite H0 in H; simpl in H.
- unfold lift in H.
- match_option_in H.
- rewrite H.
- destruct i.
- destruct i0.
- unfold lift; simpl.
- generalize (list_env_iter_total_fun
- (fun (i : {n' : nat | n' < n}) (env : df_env) =>
- list_env_iter
- (fun (j : {m' : nat | m' < m}) (env0 : df_env) =>
- df_eval_backprop_deriv
- σ (x i j) env0
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x0 l)
- (exist (fun n' : nat => n' < m) x1 l0) i j))
- (Some env)
- (bounded_seq0 m))
- grad_env (bounded_seq0 n)); intros.
- cut_to H3.
- match_option; [|tauto].
- * generalize (list_env_iter_mat_delta σ x s grad_env
- (exist (fun n' : nat => n' < n) x0 l)
- (exist (fun n' : nat => n' < m) x1 l0) d).
- intros.
- simpl in H4.
- specialize (H4 H0).
- cut_to H4.
- rewrite eqq0 in H4.
- rewrite eqq1 in H4.
- unfold lift in H4.
- symmetry; trivial.
- intros.
- specialize (closedb i).
- rewrite vforall_forall in closedb.
- apply closedb.
- * intros.
- apply list_env_iter_total_fun; intros.
- apply backprop_deriv_fully_closed_not_none.
- specialize (closedb a).
- rewrite vforall_forall in closedb.
- apply closedb.
- + specialize (vectoro_to_ovector_exists_None eqq); intros.
- destruct H1.
- specialize (vectoro_to_ovector_exists_None e); intros.
- destruct H1.
- symmetry.
- generalize (eval_deriv_genvar_fully_closed_not_none σ (x x0 x1) [mk_env_entry (s, DTfloat) 1%R]); intros.
- rewrite vforall_forall in closed.
- specialize (closed x0).
- rewrite vforall_forall in closed.
- now specialize (H1 (closed x1)).
- - Case "Var"%string.
- unfold equiv_dec, vart_eqdec; simpl.
- destruct (vart_dec v (s, DTfloat)).
- + destruct v.
- inversion e.
- subst; simpl.
- refl_simpler; simpl.
- case_eq (vartlookup grad_env (s, DTfloat)); [intros |tauto].
- simpl; f_equal.
- symmetry.
- now apply subvar_addvar_scalar_eq.
- + case_eq (vartlookup grad_env (s, DTfloat)); [intros|tauto].
- destruct v.
- unfold snd.
- destruct d0.
- * simpl.
- unfold lift.
- destruct (vartlookup grad_env (s0, DTfloat)).
- -- f_equal; simpl; symmetry.
- now apply subvar_addvar_scalar_neq.
- -- f_equal; symmetry.
- unfold subvar; simpl.
- rewrite H; lra.
- * simpl; symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros; unfold lift.
- destruct (vartlookup grad_env (s0, DTVector n0)).
- -- f_equal; simpl.
- unfold ConstVector.
- now apply subvar_addvar_scalar_neq.
- -- f_equal; unfold ConstVector.
- unfold subvar; simpl.
- rewrite H; lra.
- * simpl; symmetry.
- unfold matrixo_to_omatrix.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- unfold lift.
- destruct (vartlookup grad_env (s0, DTMatrix m n0)).
- -- f_equal; unfold ConstMatrix.
- now apply subvar_addvar_scalar_neq.
- -- f_equal; unfold ConstMatrix.
- unfold subvar; simpl.
- rewrite H; lra.
- - Case "Plus"%string.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H.
- case_eq (df_eval_backprop_deriv σ l grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- destruct closed.
- specialize (H H1).
- invcs H.
- { specialize (H0 d1).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))| apply H2].
- match_option
- ; rewrite eqq2 in H0
- ; simpl in *.
- - match_option_in H0.
- unfold lift in H0.
- match_option_in H0.
- invcs H0.
- simpl.
- f_equal.
- unfold subvar; simpl.
- rewrite eqq3.
- match_option; lra.
- - match_option_in H0; simpl.
- + unfold lift in *.
- match_option_in H0.
- + elim (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat) vin eqq3).
- }
- + destruct closed.
- specialize (H H1).
- congruence.
- + destruct closed.
- specialize (H H1).
- congruence.
- + destruct closed.
- specialize (H H1).
- case_eq (df_eval_backprop_deriv σ l grad_env 1%R)
- ; [intros ? eqq2 | intros eqq2].
- rewrite eqq2 in H.
- simpl in *.
- congruence.
- rewrite eqq2 in H.
- now apply H.
- - Case "Minus"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H.
- case_eq (df_eval_backprop_deriv σ l grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- invcs H.
- { specialize (H0 d1).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- | apply H2].
- match_option
- ; rewrite eqq2 in H0
- ; simpl in *.
- - match_option_in H0.
- unfold lift in H0.
- match_option_in H0.
- invcs H0.
- simpl.
- f_equal.
- unfold subvar; simpl.
- rewrite eqq3.
- match_option.
- + case_eq (df_eval_backprop_deriv σ r d1 (- (1))%R); intros.
- * unfold lift.
- f_equal.
- match_option.
- -- generalize (scalarMult_backprop_grad_scalar σ r s d1 d1 1%R (-1)%R)
- ; intros; simpl in H0.
- cut_to H0.
- ++ unfold df_eval_backprop_delta in H0.
- rewrite eqq3 in H0; unfold lift in H0; simpl in H0.
- replace (-1 * 1)%R with (- (1))%R in H0 by lra.
- rewrite H, eqq4 in H0; inversion H0.
- unfold subvar in H0; simpl in H0.
- rewrite eqq6, eqq5 in H0.
- inversion H0.
- lra.
- ++ congruence.
- ++ congruence.
- ++ simpl; replace (-1 * 1)%R with (- (1))%R by lra; congruence.
- ++ congruence.
- -- generalize (df_eval_backprop_deriv_preserves_lookup_not_none H (s,DTfloat))
- ; intros.
- cut_to H0; congruence.
- * generalize (backprop_deriv_fully_closed_not_none σ r d1 (- (1))%R); intros.
- destruct H0; trivial.
- + generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq4 (s,DTfloat))
- ; intros.
- cut_to H; congruence.
- - match_option_in H0.
- + unfold lift.
- unfold lift in H0.
- match_option_in H0.
- match_option.
- specialize (df_eval_backprop_deriv_preserves_lookup_not_none eqq5).
- intros.
- specialize (H (s, DTfloat)).
- generalize (backprop_deriv_fully_closed_not_none σ r d1 (1%R)); intros.
- now destruct H3.
- + generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s,DTfloat))
- ; intros.
- cut_to H; congruence.
- }
- + unfold lift in H; simpl in H.
- match_option_in H.
- - Case "Times"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- case_eq (df_eval_backprop_deriv σ l grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ l s grad_env grad_env 1%R d1)
- ; intros; simpl in H7.
- unfold df_eval_backprop_delta in H7.
- rewrite eqq1 in H7; simpl in H7.
- specialize (H7 vin vin).
- rewrite eqq0 in H7.
- generalize (backprop_deriv_fully_closed_not_none σ l grad_env (d1 * 1)%R); intros.
- specialize (H8 H1); specialize (H7 H8).
- cut_to H7; try discriminate.
- invcs H.
- { specialize (H0 d3).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- | apply H2].
- unfold lift in H7; simpl in H7.
- match_option_in H7.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- specialize (H vin).
- generalize (backprop_deriv_fully_closed_not_none σ r d3 1%R); intros.
- specialize (H9 H2).
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H10 in H0.
- unfold lift; simpl.
- unfold lift in H0.
- match_option_in H0.
- - generalize (scalarMult_backprop_grad_scalar σ r s d2 d3 1%R d)
- ; intros.
- unfold df_eval_backprop_delta in H11; simpl in H11.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros.
- specialize (H12 vin).
- case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence].
- specialize (H11 H12 H).
- rewrite H10, H13 in H11.
- generalize (backprop_deriv_fully_closed_not_none σ r d2 (d * 1)%R); intros.
- specialize (H14 H2); specialize (H11 H14 H9).
- unfold lift in H11.
- match_option_in H11; [|congruence]; f_equal.
- rewrite (split_subvar d2 d7 d0 d6); trivial.
- match_option_in H0; invcs H0.
- rewrite eqq5 in H11; invcs H11; invcs H7.
- lra.
- - now match_option_in H0.
- }
- unfold lift in H.
- match_case_in H; intros.
- rewrite H7 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ l grad_env 1%R); intros.
- now specialize (H8 H1).
- - Case "Divide"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- case_eq (df_eval_backprop_deriv σ l grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ l s grad_env grad_env 1%R (1 / d1)%R)
- ; intros; simpl in H7.
- unfold df_eval_backprop_delta in H7.
- rewrite eqq1 in H7; simpl in H7.
- specialize (H7 vin vin).
- rewrite eqq0 in H7.
- generalize (backprop_deriv_fully_closed_not_none σ l grad_env (1 / d1 * 1)%R); intros.
- specialize (H8 H1) ; specialize (H7 H8).
- cut_to H7; try discriminate.
- replace (1 / d1 * 1)%R with (1/d1)%R in H7 by lra.
- invcs H.
- { specialize (H0 d3).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- | apply H2].
- unfold lift in H7; simpl in H7.
- match_option_in H7.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- specialize (H vin).
- generalize (backprop_deriv_fully_closed_not_none σ r d3 1%R); intros.
- specialize (H9 H2).
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H10 in H0.
- unfold lift; simpl.
- unfold lift in H0.
- match_option_in H0.
- - generalize (scalarMult_backprop_grad_scalar σ r s d2 d3 1%R (- d / (d1 * d1))%R)
- ; intros; simpl in H11.
- unfold df_eval_backprop_delta in H11; simpl in H11.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros.
- specialize (H12 vin).
- case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence].
- specialize (H11 H12 H).
- rewrite H10, H13 in H11.
- generalize (backprop_deriv_fully_closed_not_none σ r d2 (- d / (d1 * d1) * 1)%R); intros.
- specialize (H14 H2); specialize (H11 H14 H9).
- unfold lift in H11.
- match_option_in H11; [|congruence]; f_equal.
- rewrite (split_subvar d2 d7 d0 d6); trivial.
- match_option_in H0; invcs H0.
- rewrite eqq5 in H11; invcs H11; invcs H7.
- lra.
- - now match_option_in H0.
- }
- unfold lift in H.
- match_case_in H; intros.
- rewrite H7 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ l grad_env 1%R); intros.
- now specialize (H8 H1).
- - Case "Square"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (2 * d)%R)
- ; intros; simpl in H2.
- unfold df_eval_backprop_delta in H2.
- rewrite eqq1 in H2; simpl in H2.
- specialize (H2 vin vin).
- rewrite eqq0 in H2.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (2 * d * 1)%R); intros.
- specialize (H3 closed); specialize (H2 H3).
- cut_to H2; try discriminate.
- invcs H.
- unfold lift in H2; match_option_in H2.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; now rewrite H2.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H2 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H3 closed).
- - Case "Exp"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (exp d)%R)
- ; intros; simpl in H2.
- unfold df_eval_backprop_delta in H2.
- rewrite eqq1 in H2; simpl in H2.
- specialize (H2 vin vin).
- rewrite eqq0 in H2.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (exp d * 1)%R); intros.
- specialize (H3 closed); specialize (H2 H3).
- cut_to H2; try discriminate.
- invcs H.
- replace (1 * exp d)%R with (exp d * 1)%R by lra.
- unfold lift in H2; match_option_in H2.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; rewrite H2.
- f_equal; lra.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H2 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H3 closed).
- - Case "Log"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (1 / d)%R)
- ; intros; simpl in H2.
- unfold df_eval_backprop_delta in H2.
- rewrite eqq1 in H2; simpl in H2.
- specialize (H2 vin vin).
- rewrite eqq0 in H2.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (1 / d * 1)%R); intros.
- specialize (H3 closed); specialize (H2 H3).
- cut_to H2; try discriminate.
- invcs H.
- replace (1 / d)%R with (1 / d * 1)%R by lra.
- unfold lift in H2; match_option_in H2.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; rewrite H2.
- f_equal; lra.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H2 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H3 closed).
- - Case "Abs"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (sign d)%R)
- ; intros; simpl in H2.
- unfold df_eval_backprop_delta in H2.
- rewrite eqq1 in H2; simpl in H2.
- specialize (H2 vin vin).
- rewrite eqq0 in H2.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (sign d * 1)%R); intros.
- specialize (H3 closed); specialize (H2 H3).
- cut_to H2; try discriminate.
- invcs H.
- replace (1 * (@sign floatish_R d))%R with ((@sign floatish_R d) * 1)%R by lra.
- unfold lift in H2; match_option_in H2.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; rewrite H2.
- f_equal; lra.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H2 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H3 closed).
- - Case "Sign"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (0)%R)
- ; intros H1; simpl in H1.
- unfold df_eval_backprop_delta in H1.
- rewrite eqq1 in H1; simpl in H1.
- specialize (H1 vin vin).
- rewrite eqq0 in H1.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (0 * 1)%R); intros.
- specialize (H2 closed); specialize (H1 H2).
- cut_to H1; try discriminate.
- invcs H.
- replace (0)%R with (0 * 1)%R by lra.
- unfold lift in H1; match_option_in H1.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; rewrite H1.
- f_equal; lra.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H1 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H2 closed).
- - Case "PSign"%string.
- specialize (H grad_env vin closed).
- simpl in *.
- generalize (eval_fully_closed_not_none σ e); intros.
- specialize (H0 closed).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- + case_eq (df_eval_backprop_deriv σ e grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- generalize (scalarMult_backprop_grad_scalar σ e s grad_env grad_env 1%R (0)%R)
- ; intros H1; simpl in H1.
- unfold df_eval_backprop_delta in H1.
- rewrite eqq1 in H1; simpl in H1.
- specialize (H1 vin vin).
- rewrite eqq0 in H1.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env (0 * 1)%R); intros.
- specialize (H2 closed); specialize (H1 H2).
- cut_to H1; try discriminate.
- invcs H.
- replace (0)%R with (0 * 1)%R by lra.
- unfold lift in H1; match_option_in H1.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- ;intros.
- unfold lift; rewrite H1.
- f_equal; lra.
- + unfold lift in H.
- match_case_in H; intros.
- rewrite H1 in H; congruence.
- generalize (backprop_deriv_fully_closed_not_none σ e grad_env 1%R); intros.
- now specialize (H2 closed).
- - Case "Max"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- specialize (H0 grad_env vin H2).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros|tauto].
- rewrite eqq0 in H.
- rewrite eqq0 in H0.
- destruct (Rle_dec d d1).
- + apply H0.
- + apply H.
- - Case "VectorDot"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros | tauto].
- replace (fun rv : R => (rv * 1)%R) with id.
- rewrite vmap_id; rewrite vmap_id.
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto].
- symmetry in H.
- generalize (backprop_deriv_fully_closed_not_none σ l grad_env d0); intros.
- specialize (H7 H1).
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq1 in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H8.
- match_option.
- + match_option; [|tauto].
- generalize (backprop_deriv_fully_closed_not_none σ r d4 d); intros.
- specialize (H9 H2).
- unfold lift.
- match_option; [|tauto].
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat));intros.
- specialize (H10 vin).
- specialize (H0 d4 H10 H2).
- rewrite eqq0 in H0.
- match_option_in H0.
- symmetry in H0.
- unfold lift in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros.
- simpl in H11.
- f_equal.
- unfold lift in H8.
- rewrite (split_subvar d4 d5 d1 d6); trivial.
- rewrite <- vsum_plus.
- generalize (df_eval_backprop_delta_by_unit_partvec σ l s grad_env d0
- (fun i => (d2 i * d0 i)%R)); intros.
- specialize (H12 H1 vin).
- generalize (df_eval_backprop_delta_by_unit_partvec σ r s d4 d
- (fun i => (d i * d3 i)%R)); intros.
- specialize (H13 H2 H10).
- cut_to H12.
- cut_to H13.
- * unfold df_eval_backprop_delta in H12.
- rewrite eqq1 in H12.
- unfold df_eval_backprop_delta in H13.
- rewrite eqq4 in H13.
- unfold lift in H12.
- unfold lift in H13.
- rewrite eqq2 in H12.
- rewrite eqq3 in H13.
- invcs H12; invcs H13.
- rewrite H14, H15; lra.
- * intros.
- replace (@scaleUnitVector (@float floatish_R) n i (d i) (IZR Z0)) with
- (scalarMult (DTVector n) (d i) (UnitVector n i)).
- rewrite scalarMult_backprop_grad_scalar with (grad_env1 := d4) (grad_env2:=d4);trivial.
- unfold df_eval_backprop_delta, lift.
- rewrite eqq4.
- specialize (H11 i).
- destruct i; simpl in H11.
- simpl_closed_backprop.
- f_equal.
- rewrite eqq5 in H11.
- invcs H11.
- rewrite H15.
- now unfold scalarMult; simpl.
- apply backprop_deriv_fully_closed_not_none; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- unfold scalarMult, UnitVector, scaleUnitVector; simpl.
- apply functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)); lra.
- * intros.
- replace (@scaleUnitVector (@float floatish_R) n i (d0 i) (IZR Z0)) with
- (scalarMult (DTVector n) (d0 i) (UnitVector n i)).
- rewrite scalarMult_backprop_grad_scalar with (grad_env1 := grad_env) (grad_env2:=grad_env);trivial.
- unfold df_eval_backprop_delta, lift.
- rewrite eqq1.
- specialize (H8 i).
- destruct i; simpl in H8.
- simpl_closed_backprop.
- f_equal.
- rewrite eqq5 in H8.
- invcs H8.
- rewrite H15.
- now unfold scalarMult; simpl; lra.
- apply backprop_deriv_fully_closed_not_none; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- unfold scalarMult, UnitVector, scaleUnitVector; simpl.
- apply functional_extensionality; intros.
- destruct (equiv_dec (` x) (` i)); lra.
- + generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros.
- now specialize (H9 H2).
- + generalize (eval_deriv_genvar_fully_closed_not_none σ l [mk_env_entry (s, DTfloat) 1%R]); intros.
- now specialize (H8 H1).
- + apply FunctionalExtensionality.functional_extensionality.
- intros; unfold id; lra.
- - Case "VectorSum"%string.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H closed).
- symmetry in H.
- generalize (vectoro_to_ovector_forall_some_f H); intros HH; clear H.
- replace (@lift (@df_env floatish_R) R
- (fun e : df_env => subvar (s, DTfloat) e d0)
- (df_eval_backprop_deriv σ v grad_env
- (ConstVector n 1%R)))
- with
- (df_eval_backprop_delta σ v (s, DTfloat) grad_env
- (ConstVector n 1%R)).
- rewrite df_eval_backprop_delta_by_unit_parts with (d1 := d); trivial.
- * intros.
- specialize (HH i).
- unfold df_eval_backprop_delta.
- rewrite eqq0.
- replace (UnitVector n i) with
- (coerce
- (df_eval_backward_gen_top_obligation_2 UnitAnn (DTVector n) v n eq_refl i)
- (UnitVector n i)); trivial.
- destruct i.
- now simpl.
- * unfold df_eval_backprop_delta.
- now rewrite eqq0.
- + specialize (H closed).
- symmetry in H.
- specialize (vectoro_to_ovector_exists_None H); intros.
- destruct H0.
- unfold lift in e.
- match_option_in e.
- generalize (backprop_deriv_fully_closed_not_none
- σ v grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_2
- UnitAnn (DTVector n) v n eq_refl x)
- (UnitVector n x))); intros.
- specialize (H0 closed); tauto.
- - Case "MatrixSum"%string.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H closed).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros HH; simpl in HH.
- replace (@lift (@df_env floatish_R) R
- (fun e : df_env => subvar (s, DTfloat) e d0)
- (df_eval_backprop_deriv σ v grad_env
- (ConstMatrix m n 1%R)))
- with
- (df_eval_backprop_delta σ v (s, DTfloat) grad_env
- (ConstMatrix m n 1%R)).
- rewrite df_eval_backprop_delta_by_unit_parts_mat with (d1 := d); trivial.
- * intros.
- specialize (HH i).
- specialize (apply vectoro_to_ovector_forall_some_f HH); intros.
- specialize (H0 j); simpl in H0.
- unfold df_eval_backprop_delta.
- rewrite eqq0.
- replace (UnitMatrix m n i j) with
- (coerce
- (df_eval_backward_gen_top_obligation_3
- UnitAnn (DTMatrix m n) v m n eq_refl i j)
- (UnitMatrix m n i j)); trivial.
- destruct i; destruct j.
- now simpl.
- * unfold df_eval_backprop_delta.
- now rewrite eqq0.
- + specialize (H closed).
- symmetry in H.
- specialize (vectoro_to_ovector_exists_None H); intros.
- destruct H0.
- unfold lift in e.
- specialize (vectoro_to_ovector_exists_None e); intros.
- destruct H0.
- match_option_in e0.
- generalize (backprop_deriv_fully_closed_not_none
- σ v grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_3
- UnitAnn (DTMatrix m n) v m n eq_refl x x0)
- (UnitMatrix m n x x0))); intros.
- specialize (H0 closed); tauto.
- - Case "VectorElem"%string.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H closed).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H.
- unfold lift.
- replace (fun k : {n' : nat | n' < n} => if equiv_dec (` k) (` i) then 1%R else 0%R)
- with
- (UnitVector n i).
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env (UnitVector n i)); intros.
- specialize (H1 closed).
- specialize (H0 i).
- match_option; symmetry; [|tauto].
- destruct i; simpl in *.
- unfold lift in H0; f_equal.
- rewrite eqq1 in H0.
- now invcs H0.
- unfold UnitVector.
- apply FunctionalExtensionality.functional_extensionality; intros.
- trivial.
- + specialize (H closed).
- symmetry in H.
- specialize (vectoro_to_ovector_exists_None H); intros.
- destruct H0.
- unfold lift in e.
- match_option_in e.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_2
- UnitAnn (DTVector n) l n eq_refl x)
- (UnitVector n x))); intros.
- specialize (H0 closed); tauto.
- - Case "MatrixElem"%string.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H closed).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H.
- unfold lift.
- replace
- (fun (k1 : {n' : nat | n' < m}) (k2 : {m' : nat | m' < n}) =>
- if equiv_dec (` k1) (` i) then
- if equiv_dec (` k2) (` j) then 1%R else 0%R else 0%R)
- with (UnitMatrix m n i j).
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env (UnitMatrix m n i j)); intros.
- specialize (H1 closed).
- specialize (H0 i).
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H2.
- specialize (H2 j).
- match_option; symmetry; [|tauto].
- destruct i; destruct j; simpl in *.
- unfold lift in H2; f_equal.
- rewrite eqq1 in H2.
- now invcs H2.
- unfold UnitMatrix.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- trivial.
- + specialize (H closed).
- symmetry in H.
- specialize (vectoro_to_ovector_exists_None H); intros.
- destruct H0.
- unfold lift in e.
- specialize (vectoro_to_ovector_exists_None e); intros.
- destruct H0.
- match_option_in e0.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_3
- UnitAnn (DTMatrix m n) l m n eq_refl x x0)
- (UnitMatrix m n x x0))); intros.
- specialize (H0 closed); tauto.
- - Case "MatrixVectorMult"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros | tauto].
- assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option; [|tauto].
- assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- simpl_closed_backprop.
- simpl_closed_backprop.
- f_equal.
- rewrite eqq1, eqq in H.
- unfold lift in H; symmetry in H.
- specialize (vectoro_to_ovector_forall_some_f H); intros.
- specialize (H0 env).
- simpler2.
- cut_to H0; try congruence.
- rewrite eqq0 in H0.
- rewrite (split_subvar env env0 d3 val); trivial.
- specialize (H9 i); simpl in H9.
- specialize (vectoro_to_ovector_forall_some_f H9); intros.
- replace (@vsum floatish_R n (fun j : {n' : nat | n' < n} => (d i j * d2 j + d1 i j * d0 j)%R))
- with ((vsum (fun j => (d i j * d2 j)%R)) + (vsum (fun j => (d1 i j * d0 j)%R)))%R
- ; [|rewrite vsum_plus; f_equal].
- unfold lift in H0; symmetry in H0.
- specialize (vectoro_to_ovector_forall_some_f H0); intros.
- simpl in H11; simpl in H10.
- destruct i; simpl in eqq2; simpl in eqq3.
- unfold matrix_vector_mult in eqq3; simpl in eqq3.
- unfold UnitVector in eqq2; simpl in eqq2.
- replace (fun i : {n' : nat | n' < n} =>
- (@vsum floatish_R m
- (fun j : {n' : nat | n' < m} =>
- (d j i * (@UnitVector floatish_R m (exist (fun n' : nat => (n' < m)%nat) x l0)
- j)%R)%R)))
- with (d (exist (fun n' : nat => (n' < m)%nat) x l0)) in eqq3.
- + generalize (df_eval_backprop_delta_by_unit_partvec
- σ r s env (d (exist (fun n' : nat => n' < m) x l0 ))
- (fun i => ((d (exist (fun n' : nat => (n' < m)%nat) x l0 ) i) * (d2 i))%R))
- ; intros.
- specialize (H12 H2).
- cut_to H12; try congruence.
- * unfold df_eval_backprop_delta in H12.
- rewrite eqq3, eqq4 in H12.
- unfold lift in H12.
- invcs H12.
- apply Rplus_eq_compat_l.
- generalize (df_eval_backprop_delta_by_unit_partmat
- σ l s grad_env
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- ((if equiv_dec (` i) x then 1 else 0) * d0 j)%R)
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- ((if equiv_dec (` i) x then 1 else 0) * (d1 i j) * (d0 j))%R))
- ; intros.
- specialize (H12 H1).
- cut_to H12; try congruence.
- -- unfold df_eval_backprop_delta in H12.
- rewrite eqq1 in H12.
- unfold lift in H12.
- rewrite eqq2 in H12.
- invcs H12.
- rewrite H15.
- replace
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- ((if equiv_dec (` i) x then 1 else 0) * d1 i j * d0 j)%R) with
- (fun (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}) =>
- (UnitVector m (exist _ x l0) i * d1 i j * d0 j)%R).
- now rewrite msum_unitvector.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- now unfold UnitVector; simpl.
- -- intros.
- rewrite scalarMult_backprop_grad_scalar with (grad_env2:=grad_env)
- ;trivial; try congruence.
- unfold lift.
- unfold df_eval_backprop_delta.
- rewrite eqq1.
- unfold lift.
- specialize (vectoro_to_ovector_forall_some_f H); intros.
- specialize (H13 i); simpl in H13.
- specialize (vectoro_to_ovector_forall_some_f H13); intros.
- specialize (H15 j); intros; simpl in H15.
- destruct i; destruct j; simpl in H15.
- match_option_in H15.
- f_equal; simpl.
- destruct (equiv_dec x0 x).
- invcs H15.
- rewrite H17; lra.
- lra.
- apply backprop_deriv_fully_closed_not_none; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- * intros.
- replace (@scaleUnitVector (@float floatish_R) n i (d (@exist nat (fun n' : nat => lt n' m) x l0) i) (IZR Z0)) with
- (scalarMult (DTVector n) (d (exist (fun n' : nat => n' < m) x l0) i) (UnitVector n i)).
- rewrite scalarMult_backprop_grad_scalar with (grad_env1 := env) (grad_env2:=env);trivial; try congruence.
- unfold scalarMult; simpl.
- unfold lift; simpl.
- specialize (H11 i).
- destruct i; simpl in H11.
- match_option_in H11.
- unfold df_eval_backprop_delta.
- rewrite eqq4,eqq5.
- unfold lift; f_equal.
- now invcs H11.
- apply backprop_deriv_fully_closed_not_none; trivial.
- apply backprop_deriv_fully_closed_not_none; trivial.
- unfold scalarMult, scaleUnitVector.
- apply FunctionalExtensionality.functional_extensionality; intros.
- unfold UnitVector.
- destruct (equiv_dec (` x0) (` i)); simpl; lra.
- + apply FunctionalExtensionality.functional_extensionality; intros.
- now rewrite vsum_unitvector.
- - Case "MatrixVectorAdd"%string.
- destruct closed.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3.
- match_option; symmetry.
- * apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H3 i).
- specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H4 i0).
- unfold lift in H4; unfold lift.
- match_option_in H4.
- unfold lift in H0.
- rewrite eqq1 in H0.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat))
- ;intros.
- specialize (H5 vin).
- specialize (H0 d2 H5 H2).
- match_option_in H0.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6.
- specialize (H6 i).
- destruct i; destruct i0; simpl; simpl in eqq2; simpl in H6.
- rewrite eqq2.
- match_option_in H6.
- generalize (list_env_iter_total_fun
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ r env
- (transpose
- (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1)) i))
- d2 (bounded_seq0 n)); intros.
- cut_to H7.
- -- case_eq (list_env_iter
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ r env
- (@transpose R m n
- (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1)) i))
- (Some d2)
- (bounded_seq0 n)); [intros |tauto].
- f_equal.
- rewrite (split_subvar d2 d5 d0 d3); trivial.
- invcs H4.
- invcs H6.
- generalize (list_env_iter_matvec_delta
- σ r s d2
- (exist (fun n' : nat => n' < n) x0 l1)
- (exist (fun n' : nat => n' < m) x l0) d3).
- intros.
- specialize (H4 eqq3 H2).
- rewrite eqq4 in H4.
- replace (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ r env
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x0 l1)
- (exist (fun n' : nat => n' < m) x l0) i)) with
- (fun (i : {m' : nat | m' < n}) (env : df_env) =>
- df_eval_backprop_deriv
- σ r env
- (@transpose R m n
- (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1)) i)) in H4.
- ++ rewrite H8 in H4.
- unfold lift in H4.
- invcs H4; lra.
- ++ apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- f_equal.
- unfold transpose, UnitMatrix.
- apply FunctionalExtensionality.functional_extensionality; intros.
- simpl.
- match_case; intros.
- match_case; intros.
- -- intros.
- apply backprop_deriv_fully_closed_not_none; trivial.
- * assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- - Case "MatrixMult"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- generalize (eval_fully_closed_not_none σ r); intros.
- specialize (H5 H2).
- case_eq (df_eval σ r); [intros | tauto].
- assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option; [|tauto].
- assert (df_eval_deriv_genvar σ r [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option; [|tauto].
- match_option; [|tauto].
- unfold lift.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- simpl_closed_backprop.
- simpl_closed_backprop.
- f_equal.
- rewrite eqq1, eqq in H.
- unfold lift in H; symmetry in H.
- specialize (vectoro_to_ovector_forall_some_f H); intros.
- specialize (H0 env).
- simpler2.
- cut_to H0; try congruence.
- rewrite eqq0 in H0.
- rewrite (split_subvar env env0 d3 val); trivial.
- specialize (H9 i); simpl in H9.
- specialize (vectoro_to_ovector_forall_some_f H9); intros.
- replace (@vsum floatish_R p (fun j => (d i j * d2 j i0 + d1 i j * d0 j i0)%R))
- with ((vsum (fun j => (d i j * d2 j i0)%R)) + (vsum (fun j => (d1 i j * d0 j i0)%R)))%R
- ; [|rewrite vsum_plus; f_equal].
- unfold lift in H0; symmetry in H0.
- specialize (vectoro_to_ovector_forall_some_f H0); intros.
- simpl in H10; simpl in H11.
- destruct i; destruct i0; simpl in eqq2; simpl in eqq3.
- unfold matrix_mult,UnitMatrix in eqq3; simpl in eqq3.
- unfold matrix_mult,UnitMatrix in eqq2; simpl in eqq2.
- generalize (df_eval_backprop_delta_by_unit_partmat
- σ l s grad_env
- (fun (i : {n' : nat | n' < m}) (k : {m' : nat | m' < p}) =>
- vsum
- (fun j : {n' : nat | n' < n} =>
- ((if equiv_dec (` i) x then
- if equiv_dec (` j) x0 then 1 else 0 else 0) * d0 k j)%R))
- (fun (i : {n' : nat | n' < m}) (k : {m' : nat | m' < p}) =>
- vsum
- (fun j : {n' : nat | n' < n} =>
- ((if equiv_dec (` i) x then
- if equiv_dec (` j) x0 then 1 else 0 else 0) *
- (d1 i k) * d0 k j)%R)))
- ; intros.
- specialize (H12 H1 vin).
- cut_to H12; try congruence.
- + unfold df_eval_backprop_delta in H12.
- rewrite eqq1 in H12.
- unfold lift in H12.
- rewrite eqq2 in H12.
- invcs H12.
- rewrite H14.
- assert (msum
- (fun (i : {n' : nat | (n' < m)%nat}) (k : {m' : nat | (m' < p)%nat}) =>
- vsum
- (fun j : {n' : nat | (n' < n)%nat} =>
- (if equiv_dec (` i) x then
- if equiv_dec (` j) x0 then 1 else 0 else 0)
- * d1 i k * d0 k j))%R =
- vsum
- (fun j : {n' : nat | (n' < p)%nat} =>
- d1 (exist (fun n' : nat => (n' < m)%nat) x l0) j *
- d0 j (exist (fun n' : nat => (n' < n)%nat) x0 l1))%R).
- * rewrite msum_transpose.
- unfold msum.
- f_equal.
- unfold transpose; simpl.
- replace (fun (i : {m' : nat | m' < p}) (j : {n' : nat | n' < m}) =>
- (@vsum floatish_R n
- (fun j0 : {n' : nat | n' < n} =>
- ((if equiv_dec (` j) x then
- if equiv_dec (` j0) x0 then 1 else 0
- else 0) * d1 j i * d0 i j0)%R))) with
- (fun (i : {m' : nat | m' < p}) (j : {n' : nat | n' < m}) =>
- if equiv_dec (` j) x then
- (d1 (exist _ x l0) i * d0 i (exist _ x0 l1))%R
- else 0%R).
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- replace (fun j : {n' : nat | n' < m} =>
- if equiv_dec (` j) x
- then
- (d1 (exist (fun n' : nat => (n' < m)%nat) x l0) x1 *
- d0 x1 (exist (fun m' : nat => (m' < n)%nat) x0 l1))%R
- else 0%R) with
- (fun j : {n' : nat | n' < m} =>
- ((d1 (exist (fun n' : nat => (n' < m)%nat) x l0) x1 *
- d0 x1 (exist (fun m' : nat => (m' < n)%nat) x0 l1))%R *
- (@UnitVector floatish_R m (exist _ x l0) j))%R).
- now rewrite vsum_unitvector.
- apply FunctionalExtensionality.functional_extensionality; intros.
- unfold UnitVector; simpl.
- destruct (equiv_dec (` x2) x); lra.
- -- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct ( equiv_dec (` x2) x).
- ++ replace
- (fun j0 : {n' : nat | n' < n} =>
- ((if equiv_dec (` j0) x0 then 1 else 0) * d1 x2 x1 * d0 x1 j0)%R)
- with
- (fun j0 : {n' : nat | n' < n} =>
- ((d1 x2 x1 * d0 x1 j0) * (UnitVector n (exist _ x0 l1) j0))%R).
- ** rewrite vsum_unitvector.
- red in e.
- subst.
- destruct x2.
- simpl.
- erewrite index_pf_irrel; eauto.
- ** apply FunctionalExtensionality.functional_extensionality; intros.
- unfold UnitVector; simpl.
- destruct (equiv_dec (` x3) x0); lra.
- ++ rewrite <- vsum_mult.
- lra.
- * rewrite H12.
- apply Rplus_eq_compat_r.
- generalize (df_eval_backprop_delta_by_unit_partmat
- σ r s env
- (fun (i : {n' : nat | n' < p}) (k : {m' : nat | m' < n}) =>
- vsum
- (fun j : {n' : nat | n' < m} =>
- (d j i *
- (if equiv_dec (` j) x then
- if equiv_dec (` k) x0 then 1 else 0 else 0))%R))
- (fun (i : {n' : nat | n' < p}) (k : {m' : nat | m' < n}) =>
- vsum
- (fun j : {n' : nat | n' < m} =>
- (d j i * d2 i k *
- (if equiv_dec (` j) x then
- if equiv_dec (` k) x0 then 1 else 0 else 0))%R)))
- ; intros.
- specialize (H13 H2).
- cut_to H13; try congruence.
- -- unfold df_eval_backprop_delta in H13.
- rewrite eqq4 in H13.
- unfold lift in H13.
- rewrite eqq3 in H13.
- invcs H13.
- rewrite H16.
- unfold msum; f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite vmap_nth.
- replace
- (fun k : {m' : nat | m' < n} =>
- (@vsum floatish_R m
- (fun j : {n' : nat | n' < m} =>
- (d j x1 * d2 x1 k *
- (if equiv_dec (` j) x then
- if equiv_dec (` k) x0 then 1 else 0 else 0))%R))) with
- (fun k : {m' : nat | m' < n} =>
- (vsum
- (fun j =>
- (d j x1 * d2 x1 k *
- (UnitVector m (exist _ x l0) j))%R)
- * (UnitVector n (exist _ x0 l1) k))%R).
- ++ rewrite vsum_unitvector.
- rewrite vsum_unitvector.
- lra.
- ++ apply FunctionalExtensionality.functional_extensionality; intros.
- rewrite Rmult_comm.
- rewrite vsum_mult.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- unfold UnitVector; simpl.
- destruct (equiv_dec (` x2) x0); destruct (equiv_dec (` x3) x); lra.
- -- intros.
- rewrite scalarMult_backprop_grad_scalar with (grad_env2:=env); trivial; try congruence.
- ++ unfold lift.
- unfold df_eval_backprop_delta.
- rewrite eqq4.
- unfold lift.
- specialize (H11 i); simpl in H11.
- specialize (vectoro_to_ovector_forall_some_f H11); intros.
- specialize (H15 j); intros; simpl in H15.
- destruct i; destruct j; simpl in H15.
- match_option_in H15.
- f_equal; simpl.
- invcs H15.
- rewrite H17.
- rewrite Rmult_comm.
- rewrite vsum_mult.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- ++ apply backprop_deriv_fully_closed_not_none; trivial.
- ++ apply backprop_deriv_fully_closed_not_none; trivial.
- + intros.
- rewrite scalarMult_backprop_grad_scalar with (grad_env2:=grad_env); trivial; try congruence.
- * unfold lift.
- unfold df_eval_backprop_delta.
- rewrite eqq1.
- unfold lift.
- specialize (vectoro_to_ovector_forall_some_f H i); intros; simpl in H11.
- specialize (vectoro_to_ovector_forall_some_f H13 j); intros; simpl in H14.
- destruct i; destruct j; simpl in H14.
- match_option_in H14.
- f_equal; simpl.
- invcs H14.
- rewrite H16.
- rewrite Rmult_comm.
- rewrite vsum_mult.
- f_equal.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- * apply backprop_deriv_fully_closed_not_none; trivial.
- * apply backprop_deriv_fully_closed_not_none; trivial.
- - Case "VectorPlus"%string.
- destruct closed.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3.
- match_option; symmetry.
- * apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H3 i).
- unfold lift in H3; unfold lift.
- match_option_in H3.
- unfold lift in H0.
- rewrite eqq1 in H0.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat))
- ;intros.
- specialize (H4 vin).
- specialize (H0 d2 H4 H2).
- match_option_in H0.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H5.
- specialize (H5 i).
- destruct i; simpl in H1; simpl; simpl in eqq2; simpl in H5.
- rewrite eqq2.
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (UnitVector n (exist (fun n' : nat => n' < n) x l0)))
- ; intros; specialize (H6 H2).
- match_option; [|tauto]; f_equal.
- rewrite eqq4 in H5; inversion H5.
- rewrite (split_subvar d2 d4 d0 d3); trivial.
- inversion H3; lra.
- * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros.
- specialize (H4 H2).
- tauto.
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_exists_None H); intros.
- destruct H3.
- unfold lift in e.
- match_option_in e.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_2
- UnitAnn (DTVector n) l n eq_refl x)
- (UnitVector n x))); intros.
- specialize (H3 H1); tauto.
- - Case "VectorMinus"%string.
- destruct closed.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H3.
- match_option; symmetry.
- * apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H3 i).
- unfold lift in H3; unfold lift.
- match_option_in H3.
- unfold lift in H0.
- rewrite eqq1 in H0.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat))
- ;intros.
- specialize (H4 vin).
- specialize (H0 d2 H4 H2).
- match_option_in H0.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H5.
- specialize (H5 i).
- destruct i; simpl in H0; simpl; simpl in eqq2; simpl in H5.
- rewrite eqq2.
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (UnitVector n (exist (fun n' : nat => n' < n) x l0)))
- ; intros; specialize (H6 H2).
- generalize (scalarMult_backprop_grad_scalar
- σ r s d2 d2
- (UnitVector n (exist (fun n' : nat => n' < n) x l0))
- (-1)%R ); intros.
- specialize (H7 H4 H4).
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (scalarMult (DTVector n) (-1)%R
- (UnitVector n (exist (fun n' : nat => n' < n) x l0))))
- ; intros; specialize (H8 H2).
- specialize (H7 H8 H6).
- unfold df_eval_backprop_delta in H7; simpl in H7.
- rewrite eqq3 in H7.
- unfold lift in H7; simpl in H7.
- replace (fun i : {n' : nat | n' < n} =>
- (- (@UnitVector floatish_R n (exist (fun n' : nat => (n' < n)%nat) x l0)) i)%R)
- with
- (fun i : {n' : nat | n' < n} =>
- (-1 * UnitVector n (exist (fun n' : nat => (n' < n)%nat) x l0) i)%R).
- match_option; [|tauto]; f_equal.
- rewrite eqq4 in H7.
- rewrite (split_subvar d2 d4 d0 d3); trivial.
- rewrite H5 in H7.
- inversion H7.
- rewrite H10.
- inversion H3; lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros.
- specialize (H4 H2).
- tauto.
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_exists_None H); intros.
- destruct H3.
- unfold lift in e.
- match_option_in e.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_2 UnitAnn (DTVector n) l n eq_refl x)
- (UnitVector n x))); intros.
- specialize (H3 H1); tauto.
- - Case "MatrixPlus"%string.
- destruct closed.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H1.
- match_option; symmetry.
- * apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H3 i).
- specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H4 i0).
- unfold lift in H4; unfold lift.
- match_option_in H4.
- unfold lift in H0.
- rewrite eqq1 in H0.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat))
- ;intros.
- specialize (H5 vin).
- specialize (H0 d2 H5 H2).
- match_option_in H0.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6.
- specialize (H6 i).
- specialize (apply vectoro_to_ovector_forall_some_f H6); intros; simpl in H7.
- specialize (H7 i0).
- destruct i; destruct i0; simpl; simpl in eqq2; simpl in H7.
- rewrite eqq2.
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0)
- (exist (fun n' : nat => n' < m) x0 l1)))
- ; intros; specialize (H8 H2).
- match_option; [|tauto]; f_equal.
- rewrite eqq4 in H7; inversion H7.
- rewrite (split_subvar d2 d4 d0 d3); trivial.
- inversion H4; lra.
- * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros.
- now specialize (H4 H2).
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_exists_None H); intros.
- destruct H3.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H3.
- unfold lift in e0.
- match_option_in e0.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_3 UnitAnn (DTMatrix n m) l n m eq_refl x
- x0) (UnitMatrix n m x x0))); intros.
- specialize (H3 H1); tauto.
- - Case "MatrixMinus"%string.
- destruct closed.
- specialize (H grad_env vin).
- simpl in *.
- match_option
- ; rewrite eqq in H
- ; simpl in *; match_option_in H; [|tauto|].
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_forall_some_f H); intros; simpl in H1.
- match_option; symmetry.
- * apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H3 i).
- specialize (apply vectoro_to_ovector_forall_some_f H3); intros; simpl in H4.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H4 i0).
- unfold lift in H4; unfold lift.
- match_option_in H4.
- unfold lift in H0.
- rewrite eqq1 in H0.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq2 (s, DTfloat))
- ;intros.
- specialize (H5 vin).
- specialize (H0 d2 H5 H2).
- match_option_in H0.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros; simpl in H6.
- specialize (H6 i).
- specialize (apply vectoro_to_ovector_forall_some_f H6); intros; simpl in H7.
- specialize (H7 i0).
- destruct i; destruct i0; simpl; simpl in eqq2; simpl in H7.
- rewrite eqq2.
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0)
- (exist (fun n' : nat => n' < m) x0 l1)))
- ; intros; specialize (H8 H2).
-
- generalize (scalarMult_backprop_grad_scalar
- σ r s d2 d2
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0)
- (exist (fun n' : nat => n' < m) x0 l1))
- (-1)%R ); intros.
- specialize (H9 H5 H5).
- generalize (backprop_deriv_fully_closed_not_none
- σ r d2
- (scalarMult (DTMatrix n m) (-1)%R
- (UnitMatrix n m (exist (fun n' : nat => n' < n) x l0)
- (exist (fun n' : nat => n' < m) x0 l1))))
- ; intros; specialize (H10 H2).
- specialize (H9 H10 H8).
- unfold df_eval_backprop_delta in H9; simpl in H9.
- rewrite eqq3 in H9.
- unfold lift in H9; simpl in H9.
- replace
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (-
- (@UnitMatrix floatish_R n m (exist (fun n' : nat => (n' < n)%nat) x l0)
- (exist (fun n' : nat => (n' < m)%nat) x0 l1)) i j)%R)
- with
- (fun (i : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (-1 *
- UnitMatrix n m (exist (fun n' : nat => (n' < n)%nat) x l0)
- (exist (fun n' : nat => (n' < m)%nat) x0 l1) i j)%R).
- match_option; [|tauto]; f_equal.
- rewrite eqq4 in H9.
- rewrite (split_subvar d2 d4 d0 d3); trivial.
- rewrite H7 in H9.
- inversion H9.
- rewrite H12.
- inversion H4; lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- lra.
- * generalize (eval_deriv_genvar_fully_closed_not_none σ r [mk_env_entry (s, DTfloat) 1%R]); intros.
- now specialize (H4 H2).
- + specialize (H H1).
- symmetry in H.
- specialize (apply vectoro_to_ovector_exists_None H); intros.
- destruct H3.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H3.
- unfold lift in e0.
- match_option_in e0.
- generalize (backprop_deriv_fully_closed_not_none
- σ l grad_env
- (coerce
- (df_eval_backward_gen_top_obligation_3 UnitAnn (DTMatrix n m) l n m eq_refl x
- x0) (UnitMatrix n m x x0))); intros.
- specialize (H3 H1); tauto.
- - Case "VectorScalMult"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ x); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H2).
- case_eq (df_eval σ l); [intros|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- case_eq (df_eval_backprop_deriv σ x grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- match_option; unfold lift; symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- replace (@vsum
- floatish_R n
- (fun j : @sig nat (fun n' : nat => lt n' n) =>
- Rmult (d1 j)
- (@coerce (Vector R n) (Vector R n)
- (@df_eval_backward_gen_top_obligation_2
- floatish_R UnitAnn
- (DTVector n)
- (@VectorScalMult floatish_R UnitAnn n ann x l) n
- (@eq_refl definition_function_types (DTVector n)) i)
- (@UnitVector floatish_R n i) j)))
- with (d1 i).
- generalize (scalarMult_backprop_grad_scalar
- σ x s grad_env grad_env 1%R (d1 i)); intros; simpl in H7.
- unfold df_eval_backprop_delta in H7.
- rewrite eqq1 in H7; simpl in H7.
- specialize (H7 vin vin).
- rewrite eqq0 in H7.
- generalize (backprop_deriv_fully_closed_not_none
- σ x grad_env (d1 i * 1)%R ); intros.
- specialize (H8 H1); specialize (H7 H8).
- cut_to H7; try discriminate.
- invcs H.
- { specialize (H0 d3).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- | apply H2].
- unfold lift in H7; simpl in H7.
- match_option_in H7.
- replace (d1 i) with (d1 i * 1)%R by lra.
- rewrite eqq3.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros.
- specialize (H vin).
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H9 in H0.
- unfold lift; simpl.
- unfold lift in H0.
- rewrite eqq2 in H0.
- generalize (scalarMult_backprop_grad_scalar σ l s d2 d3 (UnitVector n i) d)
- ; intros; simpl in H10.
- unfold df_eval_backprop_delta in H10; simpl in H10.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat));intros.
- specialize (H11 vin).
- case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence].
- specialize (H10 H11 H).
- rewrite H9, H12 in H10.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros.
- specialize (H13 i); simpl in H13.
- generalize (backprop_deriv_fully_closed_not_none
- σ l d2
- (fun i0 : {n' : nat | n' < n} => (d * UnitVector n i i0)%R)); intros.
- specialize (H14 H2); specialize (H10 H14).
- generalize (backprop_deriv_fully_closed_not_none
- σ l d3 (UnitVector n i)); intros.
- specialize (H15 H2); specialize (H10 H15).
- unfold lift in H10.
- match_option_in H10; [|congruence]; f_equal.
- replace
- (fun j : @sig nat (fun n' : nat => lt n' n) =>
- Rmult d
- (@coerce (Vector R n) (Vector R n)
- (@df_eval_backward_gen_top_obligation_2
- floatish_R UnitAnn
- (DTVector n) (@VectorScalMult floatish_R UnitAnn n ann x l) n
- (@eq_refl definition_function_types (DTVector n)) i)
- (@UnitVector floatish_R n i) j))
- with
- (fun i0 : {n' : nat | n' < n} => (d * UnitVector n i i0)%R).
- rewrite eqq4.
- rewrite (split_subvar d2 d7 d0 d6); trivial; f_equal.
- match_option_in H10; inversion H10.
- match_option_in H13.
- invcs H7; invcs H13.
- rewrite H17.
- destruct i; simpl in eqq6.
- rewrite eqq6 in eqq5.
- invcs eqq5.
- lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- now destruct i; simpl.
- }
- + symmetry.
- destruct i; simpl.
- apply vsum_unitvector.
- + specialize (H0 d3).
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros.
- specialize (H7 vin).
- specialize (H0 H7 H2).
- rewrite eqq2 in H0.
- unfold lift in H0.
- symmetry in H0.
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H8 in H0; simpl in H0.
- specialize (apply vectoro_to_ovector_exists_None H0); intros.
- destruct H9.
- generalize (backprop_deriv_fully_closed_not_none
- σ l d3
- (coerce
- (df_eval_backward_gen_top_obligation_2
- UnitAnn (DTVector n) l n eq_refl x0)
- (UnitVector n x0))); intros.
- specialize (H9 H2).
- match_option_in e.
- tauto.
- + unfold lift in H.
- generalize (backprop_deriv_fully_closed_not_none
- σ x grad_env 1%R); intros.
- specialize (H7 H1).
- match_option_in H.
- tauto.
- - Case "MatrixScalMult"%string.
- destruct closed.
- specialize (H grad_env vin H1).
- simpl in *.
- generalize (eval_fully_closed_not_none σ x); intros.
- specialize (H3 H1).
- match_case; [intros|tauto].
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d0 eqq0|tauto].
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H2).
- case_eq (df_eval σ l); [intros|tauto].
- match_option
- ; rewrite eqq in H
- ; simpl in *; rewrite eqq0 in H.
- case_eq (df_eval_backprop_deriv σ x grad_env 1%R)
- ; [intros ? eqq1 | intros eqq1]
- ; rewrite eqq1 in H
- ; simpl in *
- ; try discriminate.
- match_option; unfold lift; symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- replace
- (@msum floatish_R n m
- (fun (i1 : @sig nat (fun n' : nat => lt n' n))
- (j : @sig nat (fun m' : nat => lt m' m)) =>
- Rmult (d1 i1 j)
- (@coerce (Matrix R n m) (Matrix R n m)
- (@df_eval_backward_gen_top_obligation_3 floatish_R UnitAnn
- (DTMatrix n m) (@MatrixScalMult floatish_R UnitAnn n m ann x l) n m
- (@eq_refl definition_function_types (DTMatrix n m)) i i0)
- (@UnitMatrix floatish_R n m i i0) i1 j)))
- with (d1 i i0).
- generalize (scalarMult_backprop_grad_scalar
- σ x s grad_env grad_env 1%R (d1 i i0)); intros; simpl in H7.
- unfold df_eval_backprop_delta in H7.
-
- rewrite eqq1 in H7; simpl in H7.
- specialize (H7 vin vin).
- rewrite eqq0 in H7.
- generalize (backprop_deriv_fully_closed_not_none
- σ x grad_env (d1 i i0 * 1)%R ); intros.
- specialize (H8 H1); specialize (H7 H8).
- cut_to H7; try discriminate.
- invcs H.
- { specialize (H0 d3).
- cut_to H0;
- [| now apply (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat))
- | apply H2].
- unfold lift in H7; simpl in H7.
- match_option_in H7.
- replace (d1 i i0) with (d1 i i0 * 1)%R by lra.
- rewrite eqq3.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros.
- specialize (H vin).
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H9 in H0.
- unfold lift in H0.
- rewrite eqq2 in H0.
- generalize (scalarMult_backprop_grad_scalar σ l s d2 d3 (UnitMatrix n m i i0) d)
- ; intros; simpl in H10.
- unfold df_eval_backprop_delta in H10; simpl in H10.
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq3 (s, DTfloat));intros.
- specialize (H11 vin).
- case_eq (vartlookup d2 (s, DTfloat)); [intros | congruence].
- specialize (H10 H11 H).
- rewrite H9, H12 in H10.
- symmetry in H0.
- specialize (apply vectoro_to_ovector_forall_some_f H0); intros.
- specialize (H13 i); simpl in H13.
- specialize (apply vectoro_to_ovector_forall_some_f H13); intros.
- specialize (H14 i0); simpl in H14.
- generalize (backprop_deriv_fully_closed_not_none
- σ l d2
- (fun (i1 : {n' : nat | n' < n}) (j : {m' : nat | m' < m}) =>
- (d * UnitMatrix n m i i0 i1 j)%R)); intros.
- specialize (H15 H2); specialize (H10 H15).
- generalize (backprop_deriv_fully_closed_not_none
- σ l d3 (UnitMatrix n m i i0)); intros.
- specialize (H16 H2); specialize (H10 H16).
- unfold lift in H10.
- match_option_in H10; [|congruence]; f_equal.
- replace
- (fun (i1 : @sig nat (fun n' : nat => lt n' n))
- (j : @sig nat (fun m' : nat => lt m' m)) =>
- Rmult
- (@coerce (Matrix R n m) (Matrix R n m)
- (@df_eval_backward_gen_top_obligation_3
- floatish_R UnitAnn
- (DTMatrix n m) (@MatrixScalMult floatish_R UnitAnn n m ann x l) n m
- (@eq_refl definition_function_types (DTMatrix n m)) i i0)
- (@UnitMatrix floatish_R n m i i0) i1 j) d)
- with
- (fun i1 j => (d * UnitMatrix n m i i0 i1 j)%R).
- rewrite eqq4.
- rewrite (split_subvar d2 d7 d0 d6); trivial; f_equal.
- match_option_in H10; inversion H10.
- match_option_in H14.
- rewrite H18.
- invcs H7; invcs H14.
- destruct i; destruct i0; simpl in eqq6.
- rewrite eqq6 in eqq5.
- invcs eqq5.
- rewrite H17.
- lra.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- destruct i; destruct i0; simpl.
- lra.
- }
- + symmetry.
- destruct i; destruct i0; simpl.
- apply msum_unitmatrix.
- + specialize (H0 d3).
- generalize (df_eval_backprop_deriv_preserves_lookup_not_none eqq1 (s, DTfloat)); intros.
- specialize (H7 vin).
- specialize (H0 H7 H2).
- rewrite eqq2 in H0.
- unfold lift in H0.
- symmetry in H0.
- case_eq (vartlookup d3 (s, DTfloat)); [intros | congruence].
- rewrite H8 in H0; simpl in H0.
- specialize (apply vectoro_to_ovector_exists_None H0); intros.
- destruct H9.
- specialize (apply vectoro_to_ovector_exists_None e); intros.
- destruct H9.
- generalize (backprop_deriv_fully_closed_not_none
- σ l d3
- (coerce
- (df_eval_backward_gen_top_obligation_3
- UnitAnn (DTMatrix n m) l n m eq_refl x0 x1)
- (UnitMatrix n m x0 x1))); intros.
- specialize (H9 H2).
- match_option_in e0.
- tauto.
- + unfold lift in H.
- generalize (backprop_deriv_fully_closed_not_none
- σ x grad_env 1%R); intros.
- specialize (H7 H1).
- match_option_in H.
- tauto.
- - Case "VectorApply"%string.
- destruct closed.
- generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (0%R) :: nil) s0); intros.
- specialize (H4 H2).
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H3).
- match_case; [intros|tauto].
- specialize (H1 grad_env vin H3).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto].
- rewrite eqq1 in H1; simpl in H1.
- match_option
- ; rewrite eqq in H1; unfold lift in H1; symmetry in H1.
- + specialize (apply vectoro_to_ovector_forall_some_f H1); intros; simpl in H7.
- unfold lift; simpl.
- match_option.
- * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H8 i).
- specialize (H7 i).
- destruct i.
- simpl.
- match_nested_case.
- -- assert ( df_eval_backprop_deriv σ l grad_env v1 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- f_equal.
- simpl in H7.
- match_option_in H8.
- specialize (vectoro_to_ovector_forall_some_f eqq2); intros.
- assert (H10c := H10).
- specialize (H10 (exist _ x l0)).
- rewrite vmap_nth in H10; simpl in H10.
- match_option_in H10; simpl in H10.
- unfold UnitVector in H10; simpl in H10.
- destruct (equiv_dec x x); [|congruence].
- invcs H8; rewrite H12.
- assert (v1 =
- scalarMult (DTVector n) d4
- (UnitVector n (exist (fun n' : nat => n' < n) x l0))).
- ++ unfold scalarMult; simpl.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H10c x0); simpl in H10c.
- rewrite vmap_nth in H10c; simpl in H10c.
- match_option_in H10c.
- invcs H10c.
- destruct x0.
- unfold UnitVector; simpl.
- destruct (equiv_dec x0 x).
- ** red in e0.
- subst.
- erewrite index_pf_irrel in eqq6.
- rewrite eqq5 in eqq6.
- invcs eqq6; lra.
- ** lra.
- ++ replace (1 * d4)%R with d4 in H10 by lra.
- generalize (scalarMult_backprop_grad_scalar
- σ l s grad_env grad_env
- (UnitVector n (exist (fun n' : nat => n' < n) x l0)) d4)
- ; intros.
- simpl in H11; cut_to H11; trivial; try congruence.
- ** unfold df_eval_backprop_delta in H11.
- rewrite eqq1 in H11.
- unfold lift in H11; simpl in H11.
- rewrite H8 in eqq3.
- unfold scalarMult in eqq3; simpl in eqq3.
- match_option_in H7.
- rewrite eqq3, eqq6 in H11.
- invcs H11; rewrite H14.
- invcs H7; rewrite H11.
- generalize
- (df_eval_deriv_genvar_same
- [mk_env_entry (v, DTfloat)
- (d (exist (fun n' : nat => n' < n) x l0) )]
- s0 v); simpl; intros.
- specialize (H7 H H2).
- unfold lift, df_eval_deriv_gen_top in H7; simpl in H7.
- rewrite <- H12.
- unfold mk_genvar_env in eqq4; simpl in eqq4.
- rewrite eqq4, eqq5 in H7.
- invcs H7; lra.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- -- apply vectoro_to_ovector_exists_None in eqq2; destruct eqq2.
- rewrite vmap_nth in e; simpl in e.
- assert (df_eval [mk_env_entry (v, DTfloat) (d x0)]
- (df_deriv s0 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- match_option_in e; tauto.
- * apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0.
- assert (df_eval_deriv_genvar [mk_env_entry (v, DTfloat) (d x)] s0
- [mk_env_entry (v, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option_in e; tauto.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- - Case "MatrixApply"%string.
- destruct closed.
- generalize (eval_fully_closed_not_none (mk_env_entry (v, DTfloat) (0%R) :: nil) s0); intros.
- specialize (H4 H2).
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H3).
- match_case; [intros|tauto].
- specialize (H1 grad_env vin H3).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto].
- rewrite eqq1 in H1; simpl in H1.
- match_option
- ; rewrite eqq in H1; unfold lift in H1; symmetry in H1.
- + specialize (apply vectoro_to_ovector_forall_some_f H1); intros; simpl in H7.
- unfold lift; simpl.
- match_option.
- * specialize (apply vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8.
- symmetry.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- apply vectoro_to_ovector_forall_some_b_strong; intros.
- specialize (H8 i).
- specialize (H7 i).
- match_nested_case.
- -- assert ( df_eval_backprop_deriv σ l grad_env m1 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- f_equal.
- simpl in H7.
- specialize (vectoro_to_ovector_forall_some_f H8); intros.
- specialize (vectoro_to_ovector_forall_some_f H7); intros.
- specialize (H10 i0); simpl in H10.
- match_option_in H10.
- specialize (vectoro_to_ovector_forall_some_f eqq2); intros.
- assert (H12c := H12).
- specialize (H12 i); simpl in H12.
- specialize (vectoro_to_ovector_forall_some_f H12); intros.
- specialize (H13 i0); simpl in H13; unfold mmap in H13.
- rewrite vmap_nth in H13; simpl in H13.
- rewrite vmap_nth in H13; simpl in H13.
- destruct i; destruct i0; simpl in H13.
- simpl in *.
- unfold matrix_zip in H13.
- rewrite vmap_nth in H13; simpl in H13.
- match_option_in H13; simpl in H13.
- unfold UnitMatrix in H13; simpl in H13.
- destruct (equiv_dec x x); [|congruence].
- destruct (equiv_dec x0 x0); [|congruence].
- invcs H13; invcs H10.
- assert (m1 =
- scalarMult (DTMatrix m n) d4
- (UnitMatrix m n
- (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1))).
- ++ unfold scalarMult; simpl.
- apply FunctionalExtensionality.functional_extensionality; intros.
- apply FunctionalExtensionality.functional_extensionality; intros.
- specialize (H12c x1); simpl in H12c.
- specialize (vectoro_to_ovector_forall_some_f H12c); intros.
- specialize (H10 x2); simpl in H10; unfold mmap in H10.
-
- rewrite vmap_nth in H10; simpl in H10.
- rewrite vmap_nth in H10; simpl in H10.
- unfold matrix_zip in H10.
- rewrite vmap_nth in H10; simpl in H10.
- match_option_in H10.
- invcs H10.
- destruct x1; destruct x2.
- unfold UnitMatrix; simpl.
- destruct (equiv_dec x1 x).
- ** red in e1; subst.
- destruct (equiv_dec x2 x0).
- --- red in e1; subst.
- rewrite index_pf_irrel with (pf2 := l0) in eqq6.
- rewrite index_pf_irrel with (pf1 := l3) (pf2 := l1) in eqq6.
- rewrite eqq5 in eqq6.
- invcs eqq6; lra.
- --- lra.
- ** lra.
- ++ replace (1 * d4)%R with d4 in H15 by lra.
- generalize (scalarMult_backprop_grad_scalar
- σ l s grad_env grad_env
- (UnitMatrix m n (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1)) d4)
- ; intros.
- simpl in H13; cut_to H13; trivial; try congruence.
- ** unfold df_eval_backprop_delta in H13.
- rewrite eqq1 in H13.
- unfold lift in H13; simpl in H13.
- rewrite H10 in eqq3.
- unfold scalarMult in eqq3; simpl in eqq3.
- specialize (H11 (exist (fun n' : nat => n' < n) x0 l1)).
- simpl in H11.
- match_option_in H11.
- rewrite eqq3, eqq6 in H13.
-
- generalize
- (df_eval_deriv_genvar_same
- [mk_env_entry (v, DTfloat)
- (d (exist (fun n' : nat => n' < m) x l0)
- (exist (fun n' : nat => n' < n) x0 l1))]
- s0 v).
- simpl; intros.
- specialize (H16 H H2).
- unfold lift, df_eval_deriv_gen_top in H16; simpl in H16.
- unfold mk_genvar_env in eqq4; simpl in eqq4.
- rewrite eqq4, eqq5 in H16.
- invcs H13; rewrite H18.
- invcs H16.
- invcs H11; rewrite H15; lra.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- -- unfold matrixo_to_omatrix in eqq2.
- apply vectoro_to_ovector_exists_None in eqq2; destruct eqq2.
- apply vectoro_to_ovector_exists_None in e; destruct e.
- unfold mmap in e.
- rewrite vmap_nth in e; simpl in e.
- rewrite vmap_nth in e; simpl in e.
- rewrite matrix_zip_m_n in e.
- destruct i; destruct i0.
- simpl in e.
- assert (df_eval [mk_env_entry (v, DTfloat)
- (d x x0)]
- (df_deriv s0 (v, DTfloat)) <> None).
- apply eval_fully_closed_not_none.
- apply fully_closed_deriv; trivial.
- match_option_in e; tauto.
- * unfold matrixo_to_omatrix in eqq0.
- apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0.
- apply vectoro_to_ovector_exists_None in e; destruct e.
- assert (df_eval_deriv_genvar [mk_env_entry (v, DTfloat) (d x x0)] s0
- [mk_env_entry (v, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- match_option_in e; tauto.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- - Case "VLossfun"%string.
- destruct closed.
- generalize (eval_fully_closed_not_none
- (mk_env_entry (v1, DTfloat) (0%R) ::
- (mk_env_entry (v2, DTfloat) (0%R) :: nil)) s0); intros.
- specialize (H4 H2).
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H3).
- match_case; [intros|tauto].
- specialize (H1 grad_env vin H3).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto].
- rewrite eqq1 in H1; simpl in H1.
- match_option
- ; rewrite eqq in H1; unfold lift in H1; symmetry in H1.
- + specialize (vectoro_to_ovector_forall_some_f H1); intros; simpl in H7.
- unfold lift; simpl.
- match_nested_case.
- * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8.
- symmetry.
- match_nested_case.
- -- specialize (vectoro_to_ovector_forall_some_f eqq2); intros; simpl in H9.
- assert ( df_eval_backprop_deriv σ l grad_env v0 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- unfold snd in d1; simpl in d1.
- assert (forall i : {n' : nat | n' < n},
- df_eval [mk_env_entry (v1, DTfloat) (d i); mk_env_entry (v2, DTfloat) (r i)]
- (df_deriv s0 (v1, DTfloat))
- = Some (v0 i)).
- intros ; specialize (H9 i)
- ; rewrite vmap_nth in H9
- ; simpl in H9
- ; match_option_in H9
- ;inversion H9 ;f_equal; lra.
- generalize (df_eval_backprop_delta_by_unit_partvec σ l s grad_env v0 v).
- intros.
- specialize (H12 H3 vin).
- cut_to H12.
- ++ unfold df_eval_backprop_delta, lift in H12.
- now rewrite eqq1, eqq3 in H12.
- ++ intros.
- replace (@scaleUnitVector (@float floatish_R) n i (v0 i) (IZR Z0)) with
- (scalarMult (DTVector n) (v0 i) (UnitVector n i)) by
- (unfold scalarMult, scaleUnitVector, UnitVector;
- apply FunctionalExtensionality.functional_extensionality; intros;
- simpl; destruct (equiv_dec (` x) (` i)); lra).
- rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env); trivial.
- ** unfold df_eval_backprop_delta.
- rewrite eqq1.
- specialize (H7 i); simpl in H7.
- specialize (H8 i); simpl in H8.
- specialize (H9 i); simpl in H9.
- generalize
- (df_eval_deriv_genvar_same
- [mk_env_entry (v1, DTfloat) (d i); mk_env_entry (v2, DTfloat) (r i)]
- s0 v1); simpl; intros.
- specialize (H13 H H2).
- unfold df_eval_deriv_gen_top in H13; simpl in H13.
- unfold lift in H13; simpl in H13.
- match_option_in H13.
- --- unfold mk_genvar_env in H8; simpl in H8.
- rewrite eqq4 in H8.
- specialize (H11 i).
- destruct i; simpl in H7.
- match_option_in H7.
- unfold lift; f_equal.
- unfold scalarMult; simpl.
- invcs H8.
- invcs H7.
- rewrite H14.
- rewrite <- H13 in H11.
- invcs H11.
- lra.
- --- assert (df_eval_deriv_genvar
- [mk_env_entry (v1, DTfloat) (d i);
- mk_env_entry (v2, DTfloat) (r i)] s0
- [mk_env_entry (v1, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- ** apply backprop_deriv_fully_closed_not_none; trivial.
- -- apply vectoro_to_ovector_exists_None in eqq2.
- destruct eqq2.
- rewrite vmap_nth in e; simpl in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x);
- mk_env_entry (v2, DTfloat) (r x)]
- (df_deriv s0 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none; simpl.
- apply fully_closed_deriv; trivial.
- tauto.
- * apply vectoro_to_ovector_exists_None in eqq0.
- destruct eqq0.
- match_option_in e.
- assert (df_eval_deriv_genvar
- [mk_env_entry (v1, DTfloat) (d x);
- mk_env_entry (v2, DTfloat) (r x)] s0
- [mk_env_entry (v1, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- - Case "MLossfun"%string.
- destruct closed.
- generalize (eval_fully_closed_not_none
- (mk_env_entry (v1, DTfloat) (0%R) ::
- (mk_env_entry (v2, DTfloat) (0%R) :: nil)) s0); intros.
- specialize (H4 H2).
- generalize (eval_fully_closed_not_none σ l); intros.
- specialize (H5 H3).
- match_case; [intros|tauto].
- specialize (H1 grad_env vin H3).
- case_eq (vartlookup grad_env (s, DTfloat)); [intros d1 eqq1 | tauto].
- rewrite eqq1 in H1; simpl in H1.
- match_option
- ; rewrite eqq in H1; unfold lift in H1; symmetry in H1.
- + specialize (vectoro_to_ovector_forall_some_f H1); intros; simpl in H7.
- unfold lift; simpl.
- match_nested_case.
- * specialize (vectoro_to_ovector_forall_some_f eqq0); intros; simpl in H8.
- symmetry.
- match_nested_case.
- -- specialize (vectoro_to_ovector_forall_some_f eqq2); intros; simpl in H9.
- assert ( df_eval_backprop_deriv σ l grad_env m1 <> None).
- apply backprop_deriv_fully_closed_not_none; trivial.
- match_option; [|tauto].
- unfold snd in d1; simpl in d1.
- assert (forall (i : {n' : nat | n' < m}) (j : {m' : nat | m' < n}),
- match
- df_eval [mk_env_entry (v1, DTfloat) (d i j);
- mk_env_entry (v2, DTfloat) (r i j)]
- (df_deriv s0 (v1, DTfloat))
- with
- | Some se => Some (1 * se / IZR (Z.of_nat n))%R
- | None => None
- end = Some (m1 i j)); intros.
- ++ specialize (H9 i).
- specialize (vectoro_to_ovector_forall_some_f H9); intros.
- specialize (H11 j); simpl in H11; unfold mmap in H11.
- do 2 rewrite vmap_nth in H11; simpl in H11.
- unfold matrix_zip, vector_zip in H11; simpl in H11.
- rewrite vmap_nth in H11; simpl in H11.
- match_option; rewrite eqq4 in H11; apply H11.
- ++ generalize (df_eval_backprop_delta_by_unit_partmat
- σ l s grad_env m1
- (mmap (fun u => u / IZR (Z.of_nat n))%R m0)); intros.
- specialize (H12 H3 vin).
- cut_to H12.
- ** unfold df_eval_backprop_delta, lift in H12.
- rewrite eqq1, eqq3 in H12.
- replace ((@msum floatish_R m n m0) / IZR (Z.of_nat n))%R with
- (msum (mmap (fun u : R => (u / IZR (Z.of_nat n))%R) m0)).
- --- apply H12.
- --- now rewrite msum_mmap_div_denom.
- ** intros.
- rewrite scalarMult_backprop_grad_scalar with (grad_env2 := grad_env)
- ; trivial.
- --- unfold df_eval_backprop_delta.
- rewrite eqq1.
- specialize (H7 i); simpl in H7.
- specialize (H8 i); simpl in H8.
- specialize (H11 i j); simpl in H11.
- specialize (vectoro_to_ovector_forall_some_f H7); intros.
- specialize (H13 j); simpl in H13.
- specialize (vectoro_to_ovector_forall_some_f H8); intros.
- specialize (H14 j); simpl in H14.
- generalize
- (df_eval_deriv_genvar_same
- [mk_env_entry (v1, DTfloat) (d i j);
- mk_env_entry (v2, DTfloat) (r i j)]
- s0 v1); simpl; intros.
- specialize (H15 H H2).
- unfold df_eval_deriv_gen_top in H15; simpl in H15.
- unfold lift in H15; simpl in H15.
- match_option_in H15.
- +++ unfold mk_genvar_env in H14; simpl in H14.
- rewrite eqq4 in H14.
- destruct i; destruct j; simpl in H13.
- match_option_in H13.
- unfold lift; f_equal.
- invcs H13.
- invcs H14.
- rewrite H17.
- unfold mmap.
- rewrite vmap_nth.
- rewrite vmap_nth.
- rewrite <- H16.
- rewrite <- H15 in H11.
- invcs H11.
- lra.
- +++ assert (df_eval_deriv_genvar
- [mk_env_entry (v1, DTfloat) (d i j);
- mk_env_entry (v2, DTfloat) (r i j)] s0
- [mk_env_entry (v1, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- --- apply backprop_deriv_fully_closed_not_none; trivial.
- --- apply backprop_deriv_fully_closed_not_none; trivial.
- -- unfold matrixo_to_omatrix in eqq2.
- apply vectoro_to_ovector_exists_None in eqq2.
- destruct eqq2.
- apply vectoro_to_ovector_exists_None in e.
- destruct e.
- unfold mmap in e.
- rewrite vmap_nth in e; simpl in e.
- rewrite vmap_nth in e; simpl in e.
- unfold matrix_zip,vector_zip in e; simpl in e.
- rewrite vmap_nth in e.
- match_option_in e.
- assert (df_eval [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)]
- (df_deriv s0 (v1, DTfloat)) <> None).
- apply eval_fully_closed_not_none; simpl.
- apply fully_closed_deriv; trivial.
- tauto.
- * unfold matrixo_to_omatrix in eqq0.
- apply vectoro_to_ovector_exists_None in eqq0.
- destruct eqq0.
- apply vectoro_to_ovector_exists_None in e.
- destruct e.
- match_option_in e.
- assert (df_eval_deriv_genvar
- [mk_env_entry (v1, DTfloat) (d x x0);
- mk_env_entry (v2, DTfloat) (r x x0)] s0
- [mk_env_entry (v1, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- + assert (df_eval_deriv_genvar σ l [mk_env_entry (s, DTfloat) 1%R] <> None).
- apply eval_deriv_genvar_fully_closed_not_none; trivial.
- tauto.
- Qed.
-
-(*
-Tactic Notation "DefinedFunction_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "DVector"%string
- | Case_aux c "DMatrix"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string
- | Case_aux c "VectorDot"%string
- | Case_aux c "VectorSum"%string
- | Case_aux c "MatrixSum"%string
- | Case_aux c "VectorElem"%string
- | Case_aux c "MatrixElem"%string
- | Case_aux c "MatrixVectorMult"%string
- | Case_aux c "MatrixVectorAdd"%string
- | Case_aux c "MatrixMult"%string
- | Case_aux c "VectorPlus"%string
- | Case_aux c "VectorMinus"%string
- | Case_aux c "MatrixPlus"%string
- | Case_aux c "MatrixMinus"%string
- | Case_aux c "VectorScalMult"%string
- | Case_aux c "MatrixScalMult"%string
- | Case_aux c "VectorApply"%string
- | Case_aux c "MatrixApply"%string
- | Case_aux c "VLossfun"%string
- | Case_aux c "MLossfun"%string].
-
-
- Lemma tree_backpropeq_complete_gen {T} (env gradenv : df_env)
- (dfexpr : DefinedFunction EvalAnn T) (grad : definition_function_types_interp T) :
- forall (x : SubVar),
- let xvar := (x, DTfloat) in
- vartlookup gradenv (x,DTfloat) <> None ->
- match df_eval_tree_deriv env dfexpr xvar,
- backprop_lookup (Some gradenv) xvar,
- backprop_lookup (df_eval_tree_backprop_deriv env dfexpr gradenv grad) xvar
- with
- | Some dval, Some bval0, Some bval1 => (dval*grad + bval0)%R = bval1
- | None, _, None => True
- | _, _, _ => False
- end.
- Proof.
- simpl.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_simpl) Case.
-
- - Case "Number"%string.
- intros _ _ grad gradenv xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Constant"%string.
- intros _ _ grad gradenv xinn.
- destruct (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto]; lra.
- - Case "Var"%string.
- intros sv _ grad gradenv xinn.
- case_eq (vartlookup gradenv (x, DTfloat)); simpl; intros; [| tauto].
- destruct (var_dec x sv); simpl.
- + subst.
- rewrite H; simpl.
- rewrite lookup_update.
- destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (sv, DTfloat)); [| congruence].
- unfold addvar; simpl.
- rewrite H.
- lra.
- + destruct (@equiv_dec var_type _ _ _ (sv, DTfloat) (x, DTfloat)); [congruence | ].
- case_eq (vartlookup gradenv (sv, DTfloat)); simpl; intros.
- * rewrite lookup_update_neq by congruence.
- rewrite H.
- lra.
- * rewrite H.
- lra.
- - Case "Plus"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr grad ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' grad) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Minus"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl grad gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv grad ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl grad gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv grad); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Times"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (get_annotation r * grad)%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (get_annotation l * grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (get_annotation l * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl (get_annotation r * grad)%R gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv (get_annotation r * grad)%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Divide"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- case_eq (df_eval_tree_deriv env l (x, DTfloat))
- ; [intros dl eqdl | intros eqdl]
- ; rewrite eqdl in IHl.
- + case_eq (df_eval_tree_deriv env r (x, DTfloat))
- ; [intros dr eqdr | intros eqdr]
- ; rewrite eqdr in IHr.
- * specialize (IHl (grad / get_annotation r)%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHl
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHl
- ; [ | tauto].
- simpl in IHl.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHl
- ; [ | tauto].
- specialize (IHr (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R ge').
- cut_to IHr; [ | congruence ].
- rewrite xv'eqq in IHr.
- destruct (backprop_lookup (df_eval_tree_backprop_deriv env r ge' (- get_annotation l / ((get_annotation r) * (get_annotation r)) * grad)%R) (x, DTfloat)); trivial.
- lra.
- * case_eq ( df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r)%R ); simpl; trivial; intros.
- apply IHr.
- apply (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn).
- + specialize (IHl (grad / get_annotation r)%R gradenv xinn).
- case_eq (df_eval_tree_backprop_deriv env l gradenv (grad / get_annotation r )%R); simpl; trivial; intros.
- rewrite H in IHl.
- simpl in IHl.
- generalize (df_eval_tree_backprop_deriv_preserves_lookup_not_none H (x, DTfloat) xinn); intros.
- destruct (vartlookup d (x, DTfloat)); tauto.
- - Case "Square"%string.
- intros _ e IHe grad gradenv xinn.
-
- specialize (IHe (2 * (get_annotation e) * grad)%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
-
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (2 * (get_annotation e) * grad)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Exp"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad * exp (get_annotation e))%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * exp (get_annotation e))%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
-
- - Case "Log"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad / get_annotation e)%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad / get_annotation e)%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
-
- - Case "Abs"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe (grad * (sign (get_annotation e)))%R gradenv xinn).
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_backprop_deriv env e gradenv (grad * (sign (get_annotation e)))%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- case_eq (vartlookup ge' (x, DTfloat))
- ; [intros xv' xv'eqq | intros xv'eqq]
- ; rewrite xv'eqq in IHe
- ; [ | tauto].
- simpl.
- rewrite xv'eqq.
- lra.
- - Case "Sign"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe 0%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- replace (de * 0)%R with (0)%R in IHe by lra.
- replace (0 * grad)%R with (0)%R by lra.
- apply IHe.
- - Case "PSign"%string.
- intros _ e IHe grad gradenv xinn.
- specialize (IHe 0%R gradenv xinn).
- case_eq (vartlookup gradenv (x, DTfloat))
- ; [intros xv xveqq | intros xveqq]
- ; rewrite xveqq in IHe
- ; [ | tauto].
- case_eq (df_eval_tree_deriv env e (x, DTfloat))
- ; [intros de eqde | intros eqde]
- ; rewrite eqde in IHe
- ; trivial.
- case_eq (df_eval_tree_backprop_deriv env e gradenv 0%R)
- ; [intros ge' ge'eq | intros ge'eq]
- ; rewrite ge'eq in IHe
- ; [ | tauto].
- simpl in IHe.
- simpl.
- replace (de * 0)%R with (0)%R in IHe by lra.
- replace (0 * grad)%R with (0)%R by lra.
- apply IHe.
- - Case "Max"%string.
- intros _ l r IHl IHr grad gradenv xinn.
- specialize (IHl grad gradenv xinn).
- specialize (IHr grad gradenv xinn).
- destruct (Rle_dec (get_annotation l) (get_annotation r)); simpl.
- destruct (df_eval_tree_deriv env r (x, DTfloat)); simpl; trivial.
- destruct (df_eval_tree_deriv env l (x, DTfloat)); simpl; trivial.
- Qed.
-*)
-
-End real_pfs.
-
-(*
- Section FreeVariablesExample.
- (* We need to open the string scope in order to use "a" as a string. *)
- Open Scope string_scope.
- Theorem ex1 : (df_free_variables (Plus (Var "a") (Var "b"))) = "a"::"b"::nil.
- Proof.
- (* Reflexivity doesn't need syntactically identical things on either side of =.
- * It suffices that the left-hand side beta-reduced to the right-hand side. *)
- reflexivity.
- Qed.
-
- Close Scope string_scope.
-
- End FreeVariablesExample.
-*)
diff --git a/coq/NeuralNetworks/testderiv.v b/coq/NeuralNetworks/testderiv.v
deleted file mode 100644
index c2d6e9f6..00000000
--- a/coq/NeuralNetworks/testderiv.v
+++ /dev/null
@@ -1,1802 +0,0 @@
-Require Import String.
-Require Import EquivDec.
-Require Import RelationClasses.
-Require Import List.
-Require Import NPeano.
-Require Import Lra Lia.
-Require Reals.
-Require Import Eqdep_dec.
-
-Require Import Floatish.
-Require Import Utils.
-Require Import derivlemmas.
-Require Import Coquelicot.Hierarchy.
-Require Import Coquelicot.Derive.
-Require Import Coquelicot.Rcomplements.
-Require Import DefinedFunctions.
-Require FunctionalExtensionality.
-
-Set Bullet Behavior "Strict Subproofs".
-
-Section DefinedFunctions.
-
- Context {floatish_impl:floatish}.
- Local Open Scope float.
-
-Section real_pfs.
-
- Local Existing Instance floatish_R.
- Import Reals.
- Import List.
-
- (* following returns None if not-differentiable *)
- Fixpoint df_eval_deriv_exact {Ann} {T} (σ:df_env) (df:DefinedFunction Ann T) (v:var_type) : option (definition_function_types_interp T)
- := (match df with
- | Number _ _ => Some 0
- | Constant t _ x => Some
- match t return definition_function_types_interp t with
- | DTfloat => 0
- | DTVector n => ConstVector n 0
- | DTMatrix m n => ConstMatrix m n 0
- end
- | DVector n _ dfs => vectoro_to_ovector (fun i => df_eval_deriv_exact σ (dfs i) v)
- | DMatrix n m _ df => matrixo_to_omatrix (fun i j => df_eval_deriv_exact σ (df i j) v)
- | Var x _ => Some (let t:=snd x in
- match t return definition_function_types_interp t with
- | DTfloat => if x == v then 1 else 0
- | DTVector n => ConstVector n (if x == v then 1 else 0)
- | DTMatrix m n => ConstMatrix m n (if x == v then 1 else 0)
- end)
- | Plus _ l r =>
- match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with
- | Some le, Some lr => Some (le + lr)
- | _, _ => None
- end
- | Minus _ l r =>
- match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with
- | Some le, Some lr => Some (le - lr)
- | _, _ => None
- end
- | Times _ l r =>
- match df_eval σ l, df_eval_deriv_exact σ l v, df_eval σ r, df_eval_deriv_exact σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (le * rd +
- (ld * re))
- | _, _, _, _ => None
- end
- | Divide _ l r =>
- match df_eval σ l, df_eval_deriv_exact σ l v, df_eval σ r, df_eval_deriv_exact σ r v with
- | Some le, Some ld, Some re, Some rd =>
- if Feq re 0 then None else
- Some ((ld / re) - ((le * rd) / (re * re)))
- | _, _, _, _ => None
- end
- | Square _ e =>
- match df_eval σ e, df_eval_deriv_exact σ e v with
- | Some ee, Some ed => Some (2 * ee * ed)
- | _, _ => None
- end
- | Exp _ e =>
- match df_eval σ e, df_eval_deriv_exact σ e v with
- | Some ee, Some ed => Some (ed * Fexp ee)
- | _, _ => None
- end
- | Log _ e =>
- match df_eval σ e, df_eval_deriv_exact σ e v with
- | Some ee, Some ed =>
- if ee > 0 then Some (ed / ee) else None
- | _, _ => None
- end
- | Abs _ e =>
- match df_eval σ e, df_eval_deriv_exact σ e v with
- | Some ee, Some ed =>
- if Feq ee 0 then
- (if Feq ed 0 then Some 0 else None)
- else Some (ed * (sign ee))
- | _, _ => None
- end
- | Sign _ e =>
- match df_eval σ e, df_eval_deriv_exact σ e v with
- | Some ee, Some ed =>
- if Feq ee 0 then None else Some 0
- | _, _ => None
- end
- | PSign _ e =>
- match df_eval σ e, df_eval_deriv_exact σ e v with
- | Some ee, Some ed =>
- if Feq ee 0 then None else Some 0
- | _, _ => None
- end
- | Max _ l r =>
- match df_eval σ l, df_eval σ r, df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with
- | Some le, Some re, Some ld, Some rd =>
- if Feq le re then
- (if Feq ld rd then Some ld else None)
- else
- (if le < re then Some rd else Some ld)
- | _, _, _, _=> None
- end
- | VectorElem n _ l i =>
- match (df_eval_deriv_exact σ l v) with
- | Some l' => Some (l' i)
- | _ => None
- end
- | MatrixElem m n _ l i j =>
- match (df_eval_deriv_exact σ l v) with
- | Some l' => Some (l' i j)
- | _ => None
- end
- | VectorDot n _ l r =>
- match df_eval σ l, df_eval_deriv_exact σ l v, df_eval σ r, df_eval_deriv_exact σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (vsum (fun j => (le j) * (rd j) + (ld j) * (re j)))
- | _, _, _, _ => None
- end
- | VectorSum n _ l =>
- match df_eval_deriv_exact σ l v with
- | Some ld =>
- Some (vsum ld)
- | _ => None
- end
- | MatrixSum n m _ l =>
- match df_eval_deriv_exact σ l v with
- | Some ld =>
- Some (msum ld)
- | _ => None
- end
- | VectorScalMult n _ x r =>
- match df_eval σ x, df_eval_deriv_exact σ x v, df_eval σ r, df_eval_deriv_exact σ r v with
- | Some xe, Some xd, Some re, Some rd => Some (fun j => xe * (rd j) + xd * (re j))
- | _, _, _, _ => None
- end
- | MatrixScalMult n m _ x r =>
- match df_eval σ x, df_eval_deriv_exact σ x v, df_eval σ r, df_eval_deriv_exact σ r v with
- | Some xe, Some xd, Some re, Some rd => Some (fun i j => xe * (rd i j) + xd * (re i j))
- | _, _, _, _ => None
- end
- | MatrixVectorMult n m _ l r =>
- match df_eval σ l, df_eval_deriv_exact σ l v, df_eval σ r, df_eval_deriv_exact σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (fun i => vsum (fun j => (le i j)*(rd j) + (ld i j)*(re j)))
- | _, _, _, _ => None
- end
- | MatrixVectorAdd n m _ l r =>
- match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with
- | Some le, Some re =>
- Some (fun i j => (le i j) + (re i))
- | _, _ => None
- end
- | MatrixMult n m p _ l r =>
- match df_eval σ l, df_eval_deriv_exact σ l v, df_eval σ r, df_eval_deriv_exact σ r v with
- | Some le, Some ld, Some re, Some rd =>
- Some (fun i k => vsum (fun j => (le i j)*(rd j k) + (ld i j)*(re j k)))
- | _, _, _, _ => None
- end
- | VectorPlus n _ l r =>
- match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with
- | Some l', Some r' => Some (fun i => (l' i) + (r' i))
- | _, _ => None
- end
- | VectorMinus n _ l r =>
- match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with
- | Some l', Some r' => Some (fun i => (l' i) - (r' i))
- | _, _ => None
- end
- | MatrixPlus n m _ l r =>
- match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with
- | Some l', Some r' => Some (fun i j => (l' i j) + (r' i j))
- | _, _ => None
- end
- | MatrixMinus n m _ l r =>
- match df_eval_deriv_exact σ l v, df_eval_deriv_exact σ r v with
- | Some l', Some r' => Some (fun i j => (l' i j) - (r' i j))
- | _, _ => None
- end
- | VectorApply n _ x s r =>
- match df_eval σ r, df_eval_deriv_exact σ r v with
- | Some re, Some rd =>
- vectoro_to_ovector
- (fun i =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv_exact (cons (mk_env_entry xv (re i)) nil) s xv with
- | Some sd => Some ((rd i) * sd)
- | _ => None
- end)
- | _, _ => None
- end
- | MatrixApply n m _ x s r =>
- match df_eval σ r, df_eval_deriv_exact σ r v with
- | Some re, Some rd =>
- matrixo_to_omatrix
- (fun i j =>
- let xv := (x, DTfloat):var_type in
- match df_eval_deriv_exact (cons (mk_env_entry xv (re i j)) nil) s xv with
- | Some sd => Some ((rd i j) * sd)
- | _ => None
- end)
- | _, _ => None
- end
- | VLossfun n _ v1 v2 s l r =>
- match df_eval σ l, df_eval_deriv_exact σ l v with
- | Some le, Some ld =>
- match (vectoro_to_ovector
- (fun i =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv_exact (cons (mk_env_entry xv1 (le i))
- (cons (mk_env_entry xv2 (r i)) nil)) s xv1 with
- | Some sd => Some ((ld i) * sd)
- | _ => None
- end)) with
- | Some vv => Some (vsum vv)
- | _ => None
- end
- | _, _ => None
- end
- | MLossfun n m _ v1 v2 s l r =>
- match df_eval σ l, df_eval_deriv_exact σ l v with
- | Some le, Some ld =>
- match (matrixo_to_omatrix
- (fun i j =>
- let xv1 := (v1, DTfloat):var_type in
- let xv2 := (v2, DTfloat):var_type in
- match df_eval_deriv_exact (cons (mk_env_entry xv1 (le i j))
- (cons (mk_env_entry xv2 (r i j)) nil)) s xv1 with
- | Some sd => Some ((ld i j) * sd)
- | _ => None
- end)) with
- | Some vv =>
- if Nat.eq_dec m 0%nat then None else Some ((msum vv) / (FfromZ (Z.of_nat m)))
- | _ => None
- end
- | _, _ => None
- end
- end).
-
-
- Definition addBinding σ v x := (mk_env_entry (v,DTfloat) x)::σ.
-
- Definition df_eval_at_point {Ann T} σ (df:DefinedFunction Ann T) v x
- := df_eval (addBinding σ v x) df.
-
- Definition df_R {Ann} (σ:df_env) (df:DefinedFunction Ann DTfloat) v : R -> R
- := fun x => match df_eval_at_point σ df v x with
- | Some y => y
- | None => 0%R
- end.
-
- Definition ex_deriv_df {Ann} σ (df:DefinedFunction Ann DTfloat) v (x:R)
- := fully_closed_over df ((v,DTfloat)::map (@projT1 _ _) σ) /\
- ex_derive (df_R σ df v) x.
-
- Definition is_deriv_df {Ann} σ (df:DefinedFunction Ann DTfloat) v (x y:R)
- := fully_closed_over df ((v,DTfloat)::map (@projT1 _ _) σ) /\
- is_derive (df_R σ df v) x y.
-
-
- Lemma eval_at_point_fully_closed_total {T} (σ:df_env) (df:DefinedFunction UnitAnn T) v x :
- let vl := (v,DTfloat)::map (fun ve => projT1 ve) σ in
- fully_closed_over df vl ->
- {d:definition_function_types_interp T | df_eval_at_point σ df v x = Some d}.
- Proof.
- intros.
- unfold df_eval_at_point.
- destruct (eval_fully_closed_total (addBinding σ v x) df) as [dd pfd]
- ; simpl; eauto.
- Defined.
-
- Lemma eval_at_point_diferentiable_total {σ} {df:DefinedFunction UnitAnn DTfloat} {v x y} :
- is_deriv_df σ df v x y ->
- {xve | df_eval_at_point σ df v x = Some xve}.
- Proof.
- intros [closed _ ].
- destruct (eval_at_point_fully_closed_total σ df v x); eauto.
- Defined.
-
- Definition eval_differentiable_at_point {σ} {df:DefinedFunction UnitAnn DTfloat} {v x y}
- (pf_deriv:is_deriv_df σ df v x y) :=
- proj1_sig (eval_at_point_diferentiable_total pf_deriv).
-
- Lemma is_derive_df_exp
- (df:DefinedFunction UnitAnn DTfloat) (σ:df_env) v x y
- (pf_deriv:is_deriv_df σ df v x y) :
- forall a,
- is_deriv_df σ (Exp a df) v x (y * exp (eval_differentiable_at_point pf_deriv)).
- Proof.
- unfold eval_differentiable_at_point.
- destruct (eval_at_point_diferentiable_total pf_deriv) as [xve eqqx]; simpl.
- unfold is_deriv_df; simpl; destruct pf_deriv as [base_closed base_deriv].
- split; trivial.
- generalize (is_derive_comp exp (df_R σ df v) x (exp xve) y)
- ; intros isd.
- unfold df_R, df_eval_at_point in *.
- eapply is_derive_ext; [ | eapply isd]; trivial.
- - intros; simpl.
- match_option; simpl.
- eelim eval_fully_closed_not_none; [ | eapply eqq].
- simpl; trivial.
- - rewrite eqqx.
- apply is_derive_exp.
- Qed.
-
- Lemma is_derive_df_mult
- (df1 df2:DefinedFunction UnitAnn DTfloat) (σ:df_env) v x y1 y2
- (pf_deriv1:is_deriv_df σ df1 v x y1)
- (pf_deriv2:is_deriv_df σ df2 v x y2) :
- forall a,
- is_deriv_df σ (Times a df1 df2) v x ((y1 * eval_differentiable_at_point pf_deriv2 + eval_differentiable_at_point pf_deriv1 * y2)).
- Proof.
- unfold eval_differentiable_at_point.
- intros.
- destruct (eval_at_point_diferentiable_total pf_deriv1) as [xve1 eqqx1]; simpl.
- destruct (eval_at_point_diferentiable_total pf_deriv2) as [xve2 eqqx2]; simpl.
- unfold is_deriv_df; simpl
- ; destruct pf_deriv1 as [base_closed1 base_deriv1]
- ; destruct pf_deriv2 as [base_closed2 base_deriv2].
- split; [tauto | ].
- generalize (is_derive_mult (df_R σ df1 v) (df_R σ df2 v) x y1 y2 base_deriv1 base_deriv2)
- ; intros HH.
- unfold df_R in *.
- rewrite eqqx1, eqqx2 in HH.
- eapply is_derive_ext; [ | eapply HH]; trivial.
- - intros; simpl.
- unfold df_eval_at_point; simpl.
- repeat match_option; unfold mult; simpl; lra.
- Qed.
-
-Tactic Notation "DefinedFunction_scalar_cases" tactic(first) ident(c) :=
- first;
- [ Case_aux c "Number"%string
- | Case_aux c "Constant"%string
- | Case_aux c "Var"%string
- | Case_aux c "Plus"%string
- | Case_aux c "Minus"%string
- | Case_aux c "Times"%string
- | Case_aux c "Divide"%string
- | Case_aux c "Square"%string
- | Case_aux c "Exp"%string
- | Case_aux c "Log"%string
- | Case_aux c "Abs"%string
- | Case_aux c "Sign"%string
- | Case_aux c "PSign"%string
- | Case_aux c "Max"%string].
-
-
-
-
- Ltac refl_simpler :=
- repeat
- match goal with
- | [H: @eq var_type _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H)
- | [H: @equiv var_type _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (var_type_UIP_refl H)
- | [H: @eq definition_function_types _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H)
- | [H: @equiv definition_function_types _ _ _ _ |- _ ] => try (inversion H; subst); rewrite (definition_function_types_UIP_refl H)
- end.
-
-
- Lemma df_R_total_plus σ a (l r:DefinedFunction UnitAnn DTfloat) v x :
- fully_closed_over l ((v,DTfloat)::map (@projT1 _ _) σ) ->
- fully_closed_over r ((v,DTfloat)::map (@projT1 _ _) σ) ->
- df_R σ (Plus a l r) v x = df_R σ l v x + df_R σ r v x.
- Proof.
- simpl.
- intros.
- destruct (eval_fully_closed_total (addBinding σ v x) l); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v x) r); simpl; trivial.
- unfold df_R, df_eval_at_point; simpl.
- now rewrite e, e0.
- Qed.
-
- Lemma df_R_total_minus σ a (l r:DefinedFunction UnitAnn DTfloat) v x :
- fully_closed_over l ((v,DTfloat)::map (@projT1 _ _) σ) ->
- fully_closed_over r ((v,DTfloat)::map (@projT1 _ _) σ) ->
- df_R σ (Minus a l r) v x = df_R σ l v x - df_R σ r v x.
- Proof.
- simpl.
- intros.
- destruct (eval_fully_closed_total (addBinding σ v x) l); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v x) r); simpl; trivial.
- unfold df_R, df_eval_at_point; simpl.
- now rewrite e, e0.
- Qed.
-
- Lemma df_R_total_times σ a (l r:DefinedFunction UnitAnn DTfloat) v x :
- fully_closed_over l ((v,DTfloat)::map (@projT1 _ _) σ) ->
- fully_closed_over r ((v,DTfloat)::map (@projT1 _ _) σ) ->
- df_R σ (Times a l r) v x = df_R σ l v x * df_R σ r v x.
- Proof.
- simpl.
- intros.
- destruct (eval_fully_closed_total (addBinding σ v x) l); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v x) r); simpl; trivial.
- unfold df_R, df_eval_at_point; simpl.
- now rewrite e, e0.
- Qed.
-
- Lemma df_R_total_divide σ a (l r:DefinedFunction UnitAnn DTfloat) v x :
- fully_closed_over l ((v,DTfloat)::map (@projT1 _ _) σ) ->
- fully_closed_over r ((v,DTfloat)::map (@projT1 _ _) σ) ->
- df_R σ (Divide a l r) v x = df_R σ l v x / df_R σ r v x.
- Proof.
- simpl.
- intros.
- destruct (eval_fully_closed_total (addBinding σ v x) l); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v x) r); simpl; trivial.
- unfold df_R, df_eval_at_point; simpl.
- now rewrite e, e0.
- Qed.
-
- Lemma floatish_sign :
- forall (x:R), sign x = FloatishOps.sign x.
- Proof.
- intros.
- unfold sign, FloatishOps.sign; simpl.
- case_eq (Rlt_dec x 0); intros; trivial.
- - match_case; intros.
- destruct s; lra.
- - match_case; intros.
- + destruct s; case_eq (Rgt_dec x 0); simpl; intros; lra.
- + lra.
- Qed.
-
- Lemma pos_sign_psign:
- forall (x:R), psign x = pos_sign x.
- Proof.
- unfold psign, pos_sign.
- intros.
- simpl.
- now destruct (Rge_dec x 0).
- Qed.
-
- Lemma Rmax_Fmax :
- forall (x y:R), Rmax x y = Fmax x y.
- unfold Rmax, Fmax; simpl; intros.
- match_case; intros.
- - case_eq (Rlt_dec x y); intros; trivial.
- lra.
- - case_eq (Rlt_dec x y); intros; trivial.
- lra.
- Qed.
-
-
- Theorem df_eval_deriv_exact_correct σ (df:DefinedFunction UnitAnn DTfloat) v (x:R) y
- : is_scalar_function df ->
- fully_closed_over df ((v,DTfloat)::map (@projT1 _ _) σ) ->
- df_eval_deriv_exact (addBinding σ v x) df (v,DTfloat) = Some y ->
- is_derive (df_R σ df v) x y.
- Proof.
- simpl.
- intros is_scalar.
- generalize is_scalar.
- revert y.
- pattern df.
- revert df is_scalar.
- DefinedFunction_scalar_cases (apply is_scalar_function_ind) Case; simpl; intros.
- - Case "Number"%string.
- unfold df_R, df_eval_at_point; simpl.
- inversion H0; subst.
- now apply (@is_derive_const R_AbsRing).
- - Case "Constant"%string.
- unfold df_R, df_eval_at_point; simpl.
- inversion H0; subst.
- now apply (@is_derive_const R_AbsRing).
- - Case "Var"%string.
- unfold df_R, df_eval_at_point; simpl.
- inversion H0; subst.
- simpl.
- unfold equiv_dec, vart_eqdec.
- destruct (vart_dec (sv, DTfloat) (v, DTfloat)).
- + refl_simpler; simpl.
- now apply (@is_derive_id R_AbsRing).
- + now apply (@is_derive_const R_AbsRing).
- - Case "Plus"%string.
- destruct H1.
- do 2 match_option_in H2.
- invcs H2.
- destruct is_scalar as [isc1 isc2].
- specialize (H _ isc1 H1 eqq).
- specialize (H0 _ isc2 H3 eqq0).
- eapply is_derive_ext.
- + intros.
- symmetry.
- now apply df_R_total_plus.
- + generalize (@is_derive_plus R_AbsRing); simpl
- ; intros HH; now apply HH.
- - Case "Minus"%string.
- destruct H1.
- do 2 match_option_in H2.
- invcs H2.
- destruct is_scalar as [isc1 isc2].
- specialize (H _ isc1 H1 eqq).
- specialize (H0 _ isc2 H3 eqq0).
- eapply is_derive_ext.
- + intros.
- symmetry.
- now apply df_R_total_minus.
- + generalize (@is_derive_minus R_AbsRing); simpl
- ; intros HH; now apply HH.
- - Case "Times"%string.
- destruct H1.
- do 4 match_option_in H2.
- invcs H2.
- destruct is_scalar as [isc1 isc2].
- specialize (H _ isc1 H1 eqq0).
- specialize (H0 _ isc2 H3 eqq2).
- eapply is_derive_ext.
- + intros.
- symmetry.
- now apply df_R_total_times.
- + generalize Derive.is_derive_mult
- ; unfold plus, mult; simpl; intros HH.
- replace (d * d2 + d0 * d1)%R with
- (d0 * (df_R σ r v x) + (df_R σ l v x) * d2).
- * apply HH; trivial.
- * unfold df_R, df_eval_at_point; simpl.
- rewrite eqq; rewrite eqq1.
- lra.
- - Case "Divide"%string.
- destruct H1.
- do 4 match_option_in H2.
- invcs H2.
- destruct is_scalar as [isc1 isc2].
- specialize (H _ isc1 H1 eqq0).
- specialize (H0 _ isc2 H3 eqq2).
- eapply is_derive_ext.
- + intros.
- symmetry.
- now apply df_R_total_divide.
- + generalize is_derive_div; simpl; intros HH.
- replace (d0 / d1 - d * d2 / (d1 * d1)) with
- ((d0 * (df_R σ r v x) - (df_R σ l v x) * d2) / ((df_R σ r v x)*((df_R σ r v x)*1))).
- * destruct (Req_EM_T d1 0); [congruence |].
- inversion H5.
- specialize (HH (df_R σ l v) (df_R σ r v) x d0 d2).
- replace (d0 / d1 - d * d2 / (d1 * d1))%R with
- ((d0 * df_R σ r v x - df_R σ l v x * d2) / (df_R σ r v x * (df_R σ r v x * 1))).
- apply HH; trivial.
- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq1.
- unfold df_R, df_eval_at_point; simpl.
- rewrite eqq; rewrite eqq1.
- now field.
- * unfold df_R, df_eval_at_point; simpl.
- rewrite eqq; rewrite eqq1.
- destruct (Req_EM_T d1 0); [congruence |].
- now field; trivial.
- - Case "Square"%string.
- do 2 match_option_in H1.
- invcs H1.
- specialize (H _ is_scalar H0 eqq0).
- assert (forall t, (Rsqr (df_R σ e v t) = df_R σ (Square a e) v t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial.
- rewrite e0; unfold Rsqr; trivial.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H1.
- * replace (2 * d * d0)%R with (d0 * (2 * d))%R by lra.
- apply (@is_derive_comp R_AbsRing); trivial.
- replace (d) with (df_R σ e v x).
- -- apply is_derive_sqr.
- -- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq.
- - Case "Exp"%string.
- do 2 match_option_in H1.
- invcs H1.
- specialize (H _ is_scalar H0 eqq0).
- assert (forall t, (exp (df_R σ e v t) = df_R σ (Exp a e) v t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial.
- rewrite e0; trivial.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H1.
- * apply (@is_derive_comp R_AbsRing); trivial.
- replace (d) with (df_R σ e v x).
- -- apply is_derive_exp.
- -- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq.
- - Case "Log"%string.
- do 2 match_option_in H1.
- invcs H1.
- specialize (H _ is_scalar H0 eqq0).
- assert (forall t, (ln (df_R σ e v t) = df_R σ (Log a e) v t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial.
- rewrite e0; trivial.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H1.
- * destruct (Rgt_dec d 0); [|congruence].
- inversion H3.
- apply (@is_derive_comp R_AbsRing); trivial.
- replace (d) with (df_R σ e v x).
- -- apply is_derive_ln.
- unfold df_R, df_eval_at_point; simpl.
- rewrite eqq.
- lra.
- -- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq.
- - Case "Abs"%string.
- do 2 match_option_in H1.
- invcs H1.
- specialize (H _ is_scalar H0 eqq0).
- assert (forall t, (Rabs (df_R σ e v t) = df_R σ (Abs a e) v t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial.
- rewrite e0; trivial.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H1.
- * destruct ( Req_EM_T d 0 ).
- -- destruct (Req_EM_T d0 0).
- ++inversion H3.
- apply is_derive_Rabs_df0.
- rewrite e1 in H.
- apply H.
- ++ congruence.
- -- inversion H3.
- apply (@is_derive_comp R_AbsRing); trivial.
- replace (d) with (df_R σ e v x)
- ; unfold df_R, df_eval_at_point; simpl; rewrite eqq; trivial.
- apply is_derive_abs; lra.
- - Case "Sign"%string.
- match_option_in H1.
- match_option_in H1.
- destruct (Req_EM_T d 0); [congruence|].
- invcs H1.
- specialize (H _ is_scalar H0 eqq0).
- assert (forall t, (sign (df_R σ e v t) = df_R σ (Sign a e) v t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial.
- rewrite e0; trivial.
- apply floatish_sign.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H1.
- * replace (0)%R with (d0 * 0)%R by lra.
- apply (@is_derive_comp R_AbsRing); trivial.
- apply is_derive_sign.
- unfold df_R, df_eval_at_point; simpl.
- rewrite eqq; lra.
- - Case "PSign"%string.
- match_option_in H1.
- match_option_in H1.
- destruct (Req_EM_T d 0); [congruence|].
- invcs H1.
- specialize (H _ is_scalar H0 eqq0).
- assert (forall t, (psign (df_R σ e v t) = df_R σ (PSign a e) v t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v t) e); simpl; trivial.
- rewrite e0; trivial.
- apply pos_sign_psign.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H1.
- * replace (0)%R with (d0 * 0)%R by lra.
- apply (@is_derive_comp R_AbsRing); trivial.
- apply is_derive_psign.
- unfold df_R, df_eval_at_point; simpl.
- rewrite eqq; lra.
- - Case "Max"%string.
- do 2 match_option_in H2.
- destruct H1.
- destruct is_scalar as [isc1 isc2].
- assert (forall t, (Rmax (df_R σ l v t) (df_R σ r v t)) = df_R σ (Max a l r) v t).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v t) l); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v t) r); simpl; trivial.
- rewrite e, e0; trivial.
- apply Rmax_Fmax.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H4.
- * match_option_in H2.
- match_option_in H2.
- destruct (Req_EM_T d d0).
- -- destruct (Req_EM_T d1 d2).
- ++ invcs H2.
- apply is_derive_max_alt2.
- apply H; trivial.
- apply H0; trivial.
- ++ congruence.
- -- destruct (Rlt_dec d d0).
- ++ invcs H2.
- specialize (H _ isc1 H1 eqq1).
- specialize (H0 _ isc2 H3 eqq2).
- replace (y) with ((d1 + y + (d1-y)*sign(df_R σ l v x - df_R σ r v x))/2).
- ** apply is_derive_max; trivial.
- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq, eqq0.
- ** unfold df_R, df_eval_at_point; simpl.
- rewrite eqq, eqq0.
- unfold sign; simpl.
- destruct (total_order_T 0 (d - d0)); [destruct s; lra|lra].
- ++ invcs H2.
- specialize (H _ isc1 H1 eqq1).
- specialize (H0 _ isc2 H3 eqq2).
- replace (y) with ((y + d2 + (y - d2)*sign(df_R σ l v x - df_R σ r v x))/2).
- ** apply is_derive_max; trivial.
- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq, eqq0.
- ** unfold df_R, df_eval_at_point; simpl.
- rewrite eqq, eqq0.
- unfold sign; simpl.
- destruct (total_order_T 0 (d - d0)); [destruct s; lra | lra].
- Qed.
-
- Definition df_R_vec {Ann} {n} (σ:df_env) (df:DefinedFunction Ann (DTVector n)) v :
- R -> Vector R n
- := fun x => match df_eval_at_point σ df v x with
- | Some y => y
- | None => ConstVector n 0
- end.
-
- Definition df_R_mat {Ann} {n m} (σ:df_env) (df:DefinedFunction Ann (DTMatrix n m)) v :
- R -> Matrix R n m
- := fun x => match df_eval_at_point σ df v x with
- | Some y => y
- | None => ConstMatrix n m 0
- end.
-
- Definition df_R_gen Ann T σ :
- (DefinedFunction Ann T) -> SubVar -> R -> (definition_function_types_interp T) :=
- match T with
- | DTfloat => fun df v => df_R σ df v
- | DTVector n => fun df v => df_R_vec σ df v
- | DTMatrix n m => fun df v => df_R_mat σ df v
- end.
-
- Definition is_derive_vec {n} (f : R -> Vector R n) (x:R) (df : Vector R n) :=
- forall i, is_derive (fun x0 => f x0 i) x (df i).
-
- Definition is_derive_mat {n m} (f : R -> Matrix R n m) (x:R) (df : Matrix R n m) :=
- forall i j, is_derive (fun x0 => f x0 i j) x (df i j).
-
- Definition is_derive_gen {T} (f: R->definition_function_types_interp T) (x:R)
- (df : definition_function_types_interp T)
- :=
- (match T return (R -> definition_function_types_interp T) ->
- definition_function_types_interp T -> Prop
- with
- | DTfloat => fun f df => is_derive f x df
- | DTVector n => fun f df => is_derive_vec f x df
- | DTMatrix n m => fun f df => is_derive_mat f x df
- end) f df.
-
- Definition vec_to_nat_fun {n:nat} (v:Vector R n) (i:nat) : R :=
- match lt_dec i n with
- | left pf => v (exist _ i pf)
- | right _ => 0
- end.
-
- Lemma vec_to_nat_fun_vcons_end {n} (v : Vector R n) b :
- vec_to_nat_fun (vcons b v) n = b.
- Proof.
- unfold vec_to_nat_fun; simpl.
- destruct (lt_dec n (S n)); [ | lia].
- destruct (Nat.eq_dec n n); [ | lia].
- trivial.
- Qed.
-
- Lemma vec_to_nat_fun_vcons_nend {n} (v : Vector R n) b m (pf:(m sum_n (vec_to_nat_fun v) (n-1) = r)).
- - unfold vec_to_nat_fun, sum_n, sum_n_m, Iter.iter, Iter.iter_nat, plus, zero.
- simpl.
- lra.
- - unfold vec_to_nat_fun, sum_n, sum_n_m, Iter.iter, Iter.iter_nat, plus, zero, Datatypes.id.
- simpl; intros.
- lra.
- - intros.
- simpl.
- rewrite <- H.
- destruct n0.
- + unfold vec_to_nat_fun, sum_n, sum_n_m, Iter.iter, Iter.iter_nat, plus, zero, Datatypes.id.
- simpl.
- lra.
- + simpl.
- rewrite Nat.sub_0_r.
- rewrite sum_Sn.
- unfold plus; simpl.
- rewrite vec_to_nat_fun_vcons_end.
- rewrite Rplus_comm.
- f_equal; simpl.
- erewrite sum_n_ext_loc; [ reflexivity | ]; intros.
- apply vec_to_nat_fun_vcons_nend.
- lia.
- Qed.
-
- Lemma is_derive_vsum {n} (vf : R -> Vector R n) (x:R) (df : Vector R n) :
- is_derive_vec vf x df ->
- is_derive (fun x0 => vsum (vf x0)) x (vsum df).
- Proof.
- unfold is_derive_vec; intro.
- apply (is_derive_ext (fun x0 => sum_n (vec_to_nat_fun (vf x0)) (n-1)%nat))
- ; [intros; apply sum_n_vsum |].
- replace (@vsum floatish_R n df) with (sum_n (vec_to_nat_fun df) (n-1)%nat)
- ; [|intros; apply sum_n_vsum ].
- apply (@is_derive_sum_n R_AbsRing); intros.
- unfold vec_to_nat_fun.
- destruct (lt_dec k n).
- - apply H.
- - apply (@is_derive_const R_AbsRing).
- Qed.
-
- Lemma is_derive_msum {n m} (mf : R -> Matrix R n m) (x:R) (df : Matrix R n m) :
- is_derive_mat mf x df ->
- is_derive (fun x0 => msum (mf x0)) x (msum df).
- Proof.
- simpl.
- unfold is_derive_mat; intro.
- unfold msum.
- apply is_derive_vsum.
- unfold is_derive_vec; simpl; intro.
- rewrite vmap_nth.
- apply (is_derive_ext (fun x0 => (@vsum floatish_R m) (mf x0 i))).
- - intro; now rewrite vmap_nth.
- - apply is_derive_vsum.
- now unfold is_derive_vec.
- Qed.
-
- Theorem df_eval_deriv_exact_gen_correct {T} σ (df:DefinedFunction UnitAnn T) v (x:R) y
- : fully_closed_over df ((v,DTfloat)::map (@projT1 _ _) σ) ->
- df_eval_deriv_exact (addBinding σ v x) df (v,DTfloat) = Some y ->
- is_derive_gen (df_R_gen UnitAnn T σ df v) x y.
- Proof.
- revert σ v x.
- DefinedFunction_cases (induction T, df using DefinedFunction_ind_unit) Case
- ; simpl; intros σ v0 xx; intros.
- - Case "Number"%string.
- unfold df_R, df_eval_at_point; simpl.
- inversion H0; subst.
- now apply (@is_derive_const R_AbsRing).
- - Case "Constant"%string.
- unfold df_R_gen, is_derive_gen; simpl.
- destruct t.
- + unfold df_R; simpl; invcs H0.
- now apply (@is_derive_const R_AbsRing).
- + unfold df_R_vec; simpl; invcs H0.
- unfold is_derive_vec, ConstVector; intro.
- now apply (@is_derive_const R_AbsRing).
- + unfold df_R_mat; simpl; invcs H0.
- unfold is_derive_mat, ConstMatrix; intros.
- now apply (@is_derive_const R_AbsRing).
- - Case "DVector"%string.
- unfold is_derive_vec; intros.
- specialize (vectoro_to_ovector_forall_some_f H1); intros.
- specialize (H2 i); simpl in H2.
- rewrite vforall_forall in H0.
- assert (H0c := H0).
- specialize (H0 i).
- specialize (H i (y i) σ v0 xx H0 H2); simpl in *.
- apply (is_derive_ext (df_R σ (x i) v0) ); trivial; intros.
- unfold df_R_vec, df_R, df_eval_at_point; simpl.
- match_option.
- + match_option.
- * specialize (vectoro_to_ovector_forall_some_f eqq0); intros.
- specialize (H3 i); simpl in H3.
- now rewrite eqq in H3; invcs H3.
- * unfold ConstVector.
- apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0.
- specialize (H0c x0).
- apply (eval_fully_closed_not_none (addBinding σ v0 t) (x x0)) in H0c; tauto.
- + match_option.
- specialize (vectoro_to_ovector_forall_some_f eqq0); intros.
- specialize (H3 i); simpl in H3; congruence.
- - Case "DMatrix"%string.
- unfold is_derive_mat; intros.
- unfold matrixo_to_omatrix in H1.
- specialize (vectoro_to_ovector_forall_some_f H1); intros.
- specialize (H2 i); simpl in H2.
- specialize (vectoro_to_ovector_forall_some_f H2); intros.
- specialize (H3 j); simpl in H3.
- rewrite vforall_forall in H0.
- assert (H0c := H0).
- specialize (H0 i).
- rewrite vforall_forall in H0.
- specialize (H0 j).
- specialize (H i j (y i j) σ v0 xx H0 H3); simpl in H.
- apply (is_derive_ext (df_R σ (x i j) v0)); trivial; intros.
- unfold df_R_mat, df_R, df_eval_at_point; simpl.
- match_option.
- + match_option.
- * unfold matrixo_to_omatrix in eqq0.
- specialize (vectoro_to_ovector_forall_some_f eqq0); intros.
- specialize (H4 i); simpl in H4.
- specialize (vectoro_to_ovector_forall_some_f H4); intros.
- specialize (H5 j); simpl in H5.
- now rewrite eqq in H5; invcs H5.
- * unfold ConstMatrix.
- unfold matrixo_to_omatrix in eqq0.
- apply vectoro_to_ovector_exists_None in eqq0; destruct eqq0.
- apply vectoro_to_ovector_exists_None in e; destruct e.
- specialize (H0c x0).
- rewrite vforall_forall in H0c.
- specialize (H0c x1).
- apply (eval_fully_closed_not_none (addBinding σ v0 t) (x x0 x1)) in H0c; tauto.
- + match_option.
- unfold matrixo_to_omatrix in eqq0.
- specialize (vectoro_to_ovector_forall_some_f eqq0); intros.
- specialize (H4 i); simpl in H4.
- specialize (vectoro_to_ovector_forall_some_f H4); intros.
- specialize (H5 j); simpl in H5; congruence.
- - Case "Var"%string.
- unfold is_derive_gen.
- destruct v; unfold snd in *.
- destruct d.
- + unfold df_R_gen; simpl.
- invcs H0.
- destruct (vart_dec (s, DTfloat) (v0, DTfloat)).
- * unfold equiv_dec, vart_eqdec.
- inversion e.
- destruct (vart_dec (v0, DTfloat) (v0, DTfloat)); [|congruence].
- apply (is_derive_ext id); [|apply (@is_derive_id R_AbsRing)].
- intro.
- unfold id, df_R, df_eval_at_point.
- simpl.
- unfold equiv_dec, vart_eqdec.
- destruct (vart_dec (v0, DTfloat) (v0, DTfloat)); [|congruence].
- now refl_simpler.
- * unfold equiv_dec, vart_eqdec.
- destruct (vart_dec (s, DTfloat) (v0, DTfloat)); [congruence|].
- invcs H; [congruence | ].
- unfold df_R, df_eval_at_point.
- simpl.
- unfold equiv_dec, vart_eqdec.
- destruct (vart_dec (s, DTfloat) (v0, DTfloat)); [congruence|].
- apply (@is_derive_const R_AbsRing).
- + simpl.
- invcs H0.
- unfold is_derive_vec; simpl; intros.
- unfold ConstVector.
- unfold df_R_vec, df_eval_at_point; simpl.
- unfold equiv_dec, vart_eqdec.
- destruct (vart_dec (s, DTVector n) (v0, DTfloat)); [congruence |].
- apply (@is_derive_const R_AbsRing).
- + simpl.
- invcs H0.
- unfold is_derive_mat; simpl; intros.
- unfold ConstMatrix.
- unfold df_R_mat, df_eval_at_point; simpl.
- unfold equiv_dec, vart_eqdec.
- destruct (vart_dec (s, DTMatrix m n) (v0, DTfloat)); [congruence |].
- apply (@is_derive_const R_AbsRing).
- - Case "Plus"%string.
- destruct H.
- do 2 match_option_in H0.
- invcs H0.
- specialize (IHdf1 _ σ v0 xx H eqq).
- specialize (IHdf2 _ σ v0 xx H1 eqq0).
- eapply is_derive_ext.
- + intros.
- symmetry.
- now apply df_R_total_plus.
- + generalize (@is_derive_plus R_AbsRing); simpl
- ; intros HH; now apply HH.
- - Case "Minus"%string.
- destruct H.
- do 2 match_option_in H0.
- invcs H0.
- specialize (IHdf1 _ σ v0 xx H eqq).
- specialize (IHdf2 _ σ v0 xx H1 eqq0).
- eapply is_derive_ext.
- + intros.
- symmetry.
- now apply df_R_total_minus.
- + generalize (@is_derive_minus R_AbsRing); simpl
- ; intros HH; now apply HH.
- - Case "Times"%string.
- destruct H.
- do 4 match_option_in H0.
- invcs H0.
- specialize (IHdf1 _ σ v0 xx H eqq0).
- specialize (IHdf2 _ σ v0 xx H1 eqq2).
- eapply is_derive_ext.
- + intros.
- symmetry.
- now apply df_R_total_times.
- + generalize Derive.is_derive_mult
- ; unfold plus, mult; simpl; intros HH.
- replace (d * d2 + d0 * d1)%R with
- (d0 * (df_R σ df2 v0 xx) + (df_R σ df1 v0 xx) * d2).
- * apply HH; trivial.
- * unfold df_R, df_eval_at_point; simpl.
- rewrite eqq; rewrite eqq1.
- lra.
- - Case "Divide"%string.
- destruct H.
- do 4 match_option_in H0.
- destruct (Req_EM_T d1 0); [congruence |].
- invcs H0.
- specialize (IHdf1 _ σ v0 xx H eqq0).
- specialize (IHdf2 _ σ v0 xx H1 eqq2).
- eapply is_derive_ext.
- + intros.
- symmetry.
- now apply df_R_total_divide.
- + generalize is_derive_div; simpl; intros HH.
- replace (d0 / d1 - d * d2 / (d1 * d1))%R with
- ((d0 * (df_R σ df2 v0 xx) - (df_R σ df1 v0 xx) * d2) / ((df_R σ df2 v0 xx)*((df_R σ df2 v0 xx)*1))).
- * specialize (HH (df_R σ df1 v0) (df_R σ df2 v0) xx d0 d2).
- apply HH; trivial.
- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq1.
- * unfold df_R, df_eval_at_point; simpl.
- rewrite eqq; rewrite eqq1.
- destruct (Req_EM_T d1 0); [congruence |].
- now field; trivial.
- - Case "Square"%string.
- do 2 match_option_in H0.
- invcs H0.
- specialize (IHdf _ σ v0 xx H eqq0).
- assert (forall t, (Rsqr (df_R σ df v0 t) = df_R σ (Square ann df) v0 t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial.
- rewrite e; unfold Rsqr; trivial.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H0.
- * replace (2 * d * d0)%R with (d0 * (2 * d))%R by lra.
- apply (@is_derive_comp R_AbsRing); trivial.
- replace (d) with (df_R σ df v0 xx).
- -- apply is_derive_sqr.
- -- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq.
- - Case "Exp"%string.
- do 2 match_option_in H0.
- invcs H0.
- specialize (IHdf _ σ v0 xx H eqq0).
- assert (forall t, (exp (df_R σ df v0 t) = df_R σ (Exp ann df) v0 t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial.
- rewrite e; trivial.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H0.
- * apply (@is_derive_comp R_AbsRing); trivial.
- replace (d) with (df_R σ df v0 xx).
- -- apply is_derive_exp.
- -- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq.
- - Case "Log"%string.
- do 2 match_option_in H0.
- invcs H0.
- specialize (IHdf _ σ v0 xx H eqq0).
- assert (forall t, (ln (df_R σ df v0 t) = df_R σ (Log ann df) v0 t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial.
- rewrite e; trivial.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H0.
- * destruct (Rgt_dec d 0); [|congruence].
- inversion H2.
- apply (@is_derive_comp R_AbsRing); trivial.
- replace (d) with (df_R σ df v0 xx).
- -- apply is_derive_ln.
- unfold df_R, df_eval_at_point; simpl.
- rewrite eqq.
- lra.
- -- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq.
- - Case "Abs"%string.
- do 2 match_option_in H0.
- invcs H0.
- specialize (IHdf _ σ v0 xx H eqq0).
- assert (forall t, (Rabs (df_R σ df v0 t) = df_R σ (Abs ann df) v0 t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial.
- rewrite e; trivial.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H0.
- * destruct ( Req_EM_T d 0 ).
- -- destruct ( Req_EM_T d0 0 ).
- ++ invcs H2.
- apply is_derive_Rabs_df0.
- apply IHdf.
- ++ congruence.
- -- inversion H2.
- apply (@is_derive_comp R_AbsRing); trivial.
- replace (d) with (df_R σ df v0 xx).
- ++ replace (@FloatishOps.sign floatish_R (df_R σ df v0 xx)) with (sign (df_R σ df v0 xx)).
- ** apply is_derive_abs.
- unfold df_R, df_eval_at_point; simpl.
- rewrite eqq.
- lra.
- ** apply floatish_sign.
- ++ unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq.
- - Case "Sign"%string.
- do 2 match_option_in H0.
- destruct (Req_EM_T d 0); [congruence|].
- invcs H0.
- specialize (IHdf _ σ v0 xx H eqq0).
- assert (forall t, (sign (df_R σ df v0 t) = df_R σ (Sign ann df) v0 t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial.
- rewrite e; trivial.
- apply floatish_sign.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H0.
- * replace (0)%R with (d0 * 0)%R by lra.
- apply (@is_derive_comp R_AbsRing); trivial.
- apply is_derive_sign.
- unfold df_R, df_eval_at_point; simpl.
- rewrite eqq; lra.
- - Case "PSign"%string.
- do 2 match_option_in H0.
- destruct (Req_EM_T d 0); [congruence|].
- invcs H0.
- specialize (IHdf _ σ v0 xx H eqq0).
- assert (forall t, (psign (df_R σ df v0 t) = df_R σ (PSign ann df) v0 t)).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial.
- rewrite e; trivial.
- apply pos_sign_psign.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H0.
- * replace (0)%R with (d0 * 0)%R by lra.
- apply (@is_derive_comp R_AbsRing); trivial.
- apply is_derive_psign.
- unfold df_R, df_eval_at_point; simpl.
- rewrite eqq; lra.
- - Case "Max"%string.
- do 4 match_option_in H0.
- destruct H.
- assert (forall t, (Rmax (df_R σ df1 v0 t) (df_R σ df2 v0 t)) = df_R σ (Max ann df1 df2) v0 t).
- + intros; simpl.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- rewrite e, e0; trivial.
- apply Rmax_Fmax.
- + eapply is_derive_ext.
- * intros; simpl.
- now rewrite H2.
- * destruct (Req_EM_T d d0).
- -- destruct (Req_EM_T d1 d2).
- ++ invcs H0.
- apply is_derive_max_alt2.
- apply IHdf1; trivial.
- apply IHdf2; trivial.
- ++ congruence.
- -- destruct (Rlt_dec d d0).
- ++ invcs H0.
- specialize (IHdf1 _ σ v0 xx H eqq1).
- specialize (IHdf2 _ σ v0 xx H1 eqq2).
- replace (y) with ((d1 + y + (d1-y)*sign(df_R σ df1 v0 xx - df_R σ df2 v0 xx))/2).
- ** apply is_derive_max; trivial.
- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq, eqq0.
- ** unfold df_R, df_eval_at_point; simpl.
- rewrite eqq, eqq0.
- unfold sign; simpl.
- destruct (total_order_T 0 (d - d0)); [destruct s; lra | lra].
- ++ invcs H0.
- specialize (IHdf1 _ σ v0 xx H eqq1).
- specialize (IHdf2 _ σ v0 xx H1 eqq2).
- replace (y) with ((y + d2 + (y - d2)*sign(df_R σ df1 v0 xx - df_R σ df2 v0 xx))/2).
- ** apply is_derive_max; trivial.
- unfold df_R, df_eval_at_point; simpl.
- now rewrite eqq, eqq0.
- ** unfold df_R, df_eval_at_point; simpl.
- rewrite eqq, eqq0.
- unfold sign; simpl.
- destruct (total_order_T 0 (d - d0)); [destruct s; lra | lra].
- - Case "VectorDot"%string.
- do 4 match_option_in H0.
- invcs H0.
- destruct H.
- specialize (IHdf1 _ σ v0 xx H eqq0).
- specialize (IHdf2 _ σ v0 xx H0 eqq2).
- apply (is_derive_ext (fun x0 => (@vsum floatish_R n)
- (fun i => ((df_R_vec σ df1 v0 x0 i) *
- (df_R_vec σ df2 v0 x0 i))%R ))).
- + intros.
- unfold df_R_vec, df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- now rewrite e, e0.
- + apply is_derive_vsum.
- unfold df_R_vec, df_R, df_eval_at_point; simpl.
- unfold is_derive_vec; intros.
- generalize is_derive_mult
- ; unfold plus, mult; simpl; intros HH.
- replace (d i * d2 i + d0 i * d1 i)%R with
- (d0 i * (df_R_vec σ df2 v0 xx i) + (df_R_vec σ df1 v0 xx i) * d2 i).
- * apply HH; trivial.
- * unfold df_R_vec, df_R, df_eval_at_point; simpl.
- rewrite eqq; rewrite eqq1.
- lra.
- - Case "VectorSum"%string.
- match_option_in H0.
- invcs H0.
- specialize (IHdf _ σ v0 xx H eqq).
- assert (forall t, (vsum (df_R_vec σ df v0 t)) = df_R σ (VectorSum ann df) v0 t).
- + intros; simpl.
- unfold df_R_vec, df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial.
- rewrite e; unfold Rsqr; trivial.
- + apply (is_derive_ext (fun x0 => (@vsum floatish_R n) (df_R_vec σ df v0 x0))).
- * intros; simpl.
- now rewrite H0.
- * apply is_derive_vsum.
- now simpl in IHdf.
- - Case "MatrixSum"%string.
- match_option_in H0.
- invcs H0.
- specialize (IHdf _ σ v0 xx H eqq).
- assert (forall t, (msum (df_R_mat σ df v0 t)) = df_R σ (MatrixSum ann df) v0 t).
- + intros; simpl.
- unfold df_R_mat, df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial.
- rewrite e; unfold Rsqr; trivial.
- + apply (is_derive_ext (fun x0 => (@msum floatish_R m n) (df_R_mat σ df v0 x0))).
- * intros; simpl.
- now rewrite H0.
- * apply is_derive_msum.
- now simpl in IHdf.
- - Case "VectorElem"%string.
- match_option_in H0.
- specialize (IHdf d σ v0 xx H eqq); simpl in IHdf.
- invcs H0.
- apply (is_derive_ext (fun x0 => (df_R_vec σ df v0) x0 i)); intros.
- + unfold df_R_vec, df_R, df_eval_at_point. simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial.
- now rewrite e.
- + unfold is_derive_vec in IHdf.
- apply IHdf.
- - Case "MatrixElem"%string.
- match_option_in H0.
- specialize (IHdf d σ v0 xx H eqq); simpl in IHdf.
- invcs H0.
- apply (is_derive_ext (fun x0 => (df_R_mat σ df v0) x0 i j)); intros.
- + unfold df_R_mat, df_R, df_eval_at_point. simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df); simpl; trivial.
- now rewrite e.
- + unfold is_derive_mat in IHdf.
- apply IHdf.
- - Case "MatrixVectorMult"%string.
- do 4 match_option_in H0.
- invcs H0.
- destruct H.
- specialize (IHdf1 _ σ v0 xx H eqq0).
- specialize (IHdf2 _ σ v0 xx H0 eqq2).
- unfold is_derive_vec; intros.
- apply (is_derive_ext
- (fun x0 => ((@matrix_vector_mult floatish_R m n)
- (df_R_mat σ df1 v0 x0)
- (df_R_vec σ df2 v0 x0)) i)).
- + intros.
- unfold df_R_vec, df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- rewrite e, e0.
- unfold df_R_mat, df_R, df_eval_at_point; simpl.
- now rewrite e.
- + unfold matrix_vector_mult.
- apply is_derive_vsum.
- unfold is_derive_vec; intros.
- simpl.
- unfold df_R_vec, df_R, df_eval_at_point; simpl.
- unfold is_derive_mat; intros.
- generalize Derive.is_derive_mult
- ; unfold plus, mult; simpl; intros HH.
- replace (d i i0 * d2 i0 + d0 i i0 * d1 i0)%R with
- (d0 i i0 * (df_R_vec σ df2 v0 xx i0) + (df_R_mat σ df1 v0 xx i i0) * d2 i0)%R.
- * apply HH; trivial.
- simpl in IHdf1.
- unfold is_derive_mat in IHdf1.
- apply IHdf1.
- * unfold df_R_mat, df_R_vec, df_R, df_eval_at_point; simpl.
- rewrite eqq; rewrite eqq1.
- lra.
- - Case "MatrixVectorAdd"%string.
- do 2 match_option_in H0.
- invcs H0.
- destruct H.
- specialize (IHdf1 _ σ v0 xx H eqq).
- specialize (IHdf2 _ σ v0 xx H0 eqq0).
- unfold is_derive_mat; intros.
- apply (is_derive_ext
- (fun x0 => ((@matrix_vector_add floatish_R m n)
- (df_R_mat σ df1 v0 x0)
- (df_R_vec σ df2 v0 x0)) i j)).
- + intros.
- unfold df_R_mat, df_R_vec, df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- rewrite e, e0.
- now unfold matrix_vector_add.
- + unfold matrix_vector_add; simpl.
- simpl in *.
- apply (@is_derive_plus R_AbsRing); simpl.
- apply IHdf1.
- apply IHdf2.
- - Case "MatrixMult"%string.
- do 4 match_option_in H0.
- invcs H0.
- destruct H.
- specialize (IHdf1 _ σ v0 xx H eqq0).
- specialize (IHdf2 _ σ v0 xx H0 eqq2).
- unfold is_derive_mat; intros.
- apply (is_derive_ext
- (fun x0 => ((@matrix_mult floatish_R m p n)
- (df_R_mat σ df1 v0 x0)
- (df_R_mat σ df2 v0 x0)) i j)).
- + intros.
- unfold df_R_mat, df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- now rewrite e, e0.
- + unfold matrix_mult.
- apply is_derive_vsum.
- unfold is_derive_vec; intros.
- simpl.
- unfold df_R_mat, df_R, df_eval_at_point; simpl.
- generalize Derive.is_derive_mult
- ; unfold plus, mult; simpl; intros HH.
- replace (d i i0 * d2 i0 j + d0 i i0 * d1 i0 j)%R with
- (d0 i i0 * (df_R_mat σ df2 v0 xx i0 j) + (df_R_mat σ df1 v0 xx i i0) * d2 i0 j)%R.
- * apply HH; trivial.
- * unfold df_R_mat, df_R_vec, df_R, df_eval_at_point; simpl.
- rewrite eqq; rewrite eqq1.
- lra.
- - Case "VectorPlus"%string.
- destruct H.
- do 2 match_option_in H0.
- invcs H0.
- specialize (IHdf1 _ σ v0 xx H eqq).
- specialize (IHdf2 _ σ v0 xx H1 eqq0).
- unfold is_derive_vec; intro.
- simpl in *.
- unfold is_derive_vec in IHdf1.
- unfold is_derive_vec in IHdf2.
- specialize (IHdf1 i); specialize (IHdf2 i).
- apply (is_derive_ext (fun x0 => ((df_R_vec σ df1 v0 x0 i) + (df_R_vec σ df2 v0 x0 i))%R)).
- + intros.
- symmetry.
- unfold df_R_vec, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- now rewrite e, e0.
- + generalize (@is_derive_plus R_AbsRing); simpl
- ; intros HH; now apply HH.
- - Case "VectorMinus"%string.
- destruct H.
- do 2 match_option_in H0.
- invcs H0.
- specialize (IHdf1 _ σ v0 xx H eqq).
- specialize (IHdf2 _ σ v0 xx H1 eqq0).
- unfold is_derive_vec; intro.
- simpl in *.
- unfold is_derive_vec in IHdf1.
- unfold is_derive_vec in IHdf2.
- specialize (IHdf1 i); specialize (IHdf2 i).
- apply (is_derive_ext (fun x0 => ((df_R_vec σ df1 v0 x0 i) - (df_R_vec σ df2 v0 x0 i))%R)).
- + intros.
- symmetry.
- unfold df_R_vec, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- now rewrite e, e0.
- + generalize (@is_derive_minus R_AbsRing); simpl
- ; intros HH; now apply HH.
- - Case "MatrixPlus"%string.
- destruct H.
- do 2 match_option_in H0.
- invcs H0.
- specialize (IHdf1 _ σ v0 xx H eqq).
- specialize (IHdf2 _ σ v0 xx H1 eqq0).
- unfold is_derive_mat; intros.
- simpl in *.
- unfold is_derive_mat in IHdf1.
- unfold is_derive_mat in IHdf2.
- specialize (IHdf1 i j); specialize (IHdf2 i j).
- apply (is_derive_ext (fun x0 => ((df_R_mat σ df1 v0 x0 i j) + (df_R_mat σ df2 v0 x0 i j))%R)).
- + intros.
- symmetry.
- unfold df_R_mat, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- now rewrite e, e0.
- + generalize (@is_derive_plus R_AbsRing); simpl
- ; intros HH; now apply HH.
- - Case "MatrixMinus"%string.
- destruct H.
- do 2 match_option_in H0.
- invcs H0.
- specialize (IHdf1 _ σ v0 xx H eqq).
- specialize (IHdf2 _ σ v0 xx H1 eqq0).
- unfold is_derive_mat; intros.
- simpl in *.
- unfold is_derive_mat in IHdf1.
- unfold is_derive_mat in IHdf2.
- specialize (IHdf1 i j); specialize (IHdf2 i j).
- apply (is_derive_ext (fun x0 => ((df_R_mat σ df1 v0 x0 i j) - (df_R_mat σ df2 v0 x0 i j))%R)).
- + intros.
- symmetry.
- unfold df_R_mat, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- now rewrite e, e0.
- + generalize (@is_derive_minus R_AbsRing); simpl
- ; intros HH; now apply HH.
- - Case "VectorScalMult"%string.
- do 4 match_option_in H0.
- invcs H0.
- destruct H.
- specialize (IHdf1 _ σ v0 xx H eqq0).
- specialize (IHdf2 _ σ v0 xx H0 eqq2).
- unfold is_derive_vec; intros.
- apply (is_derive_ext (fun x0 => ((df_R σ df1 v0 x0) * (df_R_vec σ df2 v0 x0 i))%R)).
- + intro.
- unfold df_R_vec, df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- now rewrite e, e0.
- + generalize Derive.is_derive_mult
- ; unfold plus, mult; simpl; intros HH.
- replace (d * d2 i + d0 * d1 i)%R with
- (d0 * (df_R_vec σ df2 v0 xx i) + (df_R σ df1 v0 xx) * (d2 i))%R.
- * apply HH; trivial.
- simpl in IHdf1; simpl in IHdf2.
- unfold is_derive_vec in IHdf2.
- apply IHdf2.
- * unfold df_R_vec, df_R, df_eval_at_point; simpl.
- rewrite eqq; rewrite eqq1.
- lra.
- - Case "MatrixScalMult"%string.
- do 4 match_option_in H0.
- invcs H0.
- destruct H.
- specialize (IHdf1 _ σ v0 xx H eqq0).
- specialize (IHdf2 _ σ v0 xx H0 eqq2).
- unfold is_derive_mat; intros.
- apply (is_derive_ext (fun x0 => ((df_R σ df1 v0 x0) * (df_R_mat σ df2 v0 x0 i j))%R)).
- + intro.
- unfold df_R_mat, df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df1); simpl; trivial.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- now rewrite e, e0.
- + generalize Derive.is_derive_mult
- ; unfold plus, mult; simpl; intros HH.
- replace (d * d2 i j + d0 * d1 i j)%R with
- (d0 * (df_R_mat σ df2 v0 xx i j) + (df_R σ df1 v0 xx) * (d2 i j))%R.
- * apply HH; trivial.
- simpl in IHdf1; simpl in IHdf2.
- unfold is_derive_mat in IHdf2.
- apply IHdf2.
- * unfold df_R_mat, df_R, df_eval_at_point; simpl.
- rewrite eqq; rewrite eqq1.
- lra.
- - Case "VectorApply"%string.
- destruct H.
- do 2 match_option_in H0.
- specialize (vectoro_to_ovector_forall_some_f H0); intros.
- specialize (IHdf2 d0 σ v0 xx H1 eqq0).
- unfold is_derive_vec; intro.
- specialize (H2 i); simpl in H2.
- match_option_in H2; invcs H2.
- simpl in IHdf2; simpl in IHdf1.
- specialize (IHdf1 d1).
- unfold is_derive_vec in IHdf2; specialize (IHdf2 i).
- generalize
- (@is_derive_comp R_AbsRing R_NormedModule
- (df_R nil df1 v)
- (fun x0 => (df_R_vec σ df2 v0 x0 i))); simpl; intros.
- specialize (H2 xx d1 (d0 i)).
- apply (is_derive_ext (fun x : R => df_R Datatypes.nil df1 v (df_R_vec σ df2 v0 x i)))
- ; intros.
- + destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- unfold df_R_vec, df_eval_at_point; simpl.
- rewrite e.
- unfold df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (mk_env_entry (v, DTfloat) (x i) :: nil) df1)
- ; simpl; trivial.
- unfold addBinding.
- rewrite e0.
- match_option.
- * specialize (vectoro_to_ovector_forall_some_f eqq2); intros.
- specialize (H3 i);simpl in H3.
- congruence.
- * apply vectoro_to_ovector_exists_None in eqq2.
- destruct eqq2.
- destruct (eval_fully_closed_total (mk_env_entry (v, DTfloat) (x x1) :: nil) df1)
- ; simpl; trivial.
- congruence.
- + apply H2.
- * unfold df_R_vec, df_R, df_eval_at_point; simpl.
- rewrite eqq.
- unfold df_R, df_eval_at_point in IHdf1.
- specialize (IHdf1 nil v (d i)).
- apply IHdf1; trivial.
- * apply IHdf2.
- - Case "MatrixApply"%string.
- destruct H.
- do 2 match_option_in H0.
- specialize (vectoro_to_ovector_forall_some_f H0); intros.
- specialize (IHdf2 d0 σ v0 xx H1 eqq0).
- unfold is_derive_mat; intros.
- specialize (H2 i); simpl in H2.
- unfold matrixo_to_omatrix in H0.
- specialize (vectoro_to_ovector_forall_some_f H2); intros.
- specialize (H3 j); simpl in H3.
- match_option_in H3; invcs H3.
- simpl in IHdf2; simpl in IHdf1.
- specialize (IHdf1 d1).
- unfold is_derive_mat in IHdf2; specialize (IHdf2 i j).
- generalize
- (@is_derive_comp R_AbsRing R_NormedModule
- (df_R nil df1 v)
- (fun x0 => (df_R_mat σ df2 v0 x0 i j)))
- ; simpl; intros.
- specialize (H3 xx d1 (d0 i j)).
- apply (is_derive_ext (fun x : R => df_R Datatypes.nil df1 v (df_R_mat σ df2 v0 x i j)))
- ; intros.
- + destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- unfold df_R_mat, df_eval_at_point; simpl.
- rewrite e.
- unfold matrixo_to_omatrix.
- destruct (eval_fully_closed_total (mk_env_entry (v, DTfloat) (x i j) :: nil) df1)
- ; simpl; trivial.
- unfold df_R, df_eval_at_point, addBinding; simpl.
- rewrite e0.
- match_option.
- * specialize (vectoro_to_ovector_forall_some_f eqq2); intros.
- specialize (H4 i); simpl in H4.
- specialize (vectoro_to_ovector_forall_some_f H4); intros.
- specialize (H6 j); simpl in H6.
- congruence.
- * apply vectoro_to_ovector_exists_None in eqq2.
- destruct eqq2.
- apply vectoro_to_ovector_exists_None in e1.
- destruct e1.
- destruct (eval_fully_closed_total (mk_env_entry (v, DTfloat) (x x1 x2) :: nil) df1)
- ; simpl; trivial.
- congruence.
- + apply H3.
- * unfold df_R_mat, df_R, df_eval_at_point; simpl.
- rewrite eqq.
- specialize (IHdf1 nil v (d i j)).
- unfold df_R, df_eval_at_point in IHdf1; simpl in IHdf1.
- apply IHdf1; trivial.
- * apply IHdf2.
- - Case "VLossfun"%string.
- destruct H.
- do 3 match_option_in H0.
- invcs H0.
- specialize (vectoro_to_ovector_forall_some_f eqq1); simpl; intros.
- specialize (IHdf2 d0 σ v0 xx H1 eqq0).
- simpl in IHdf2; simpl in IHdf1.
- unfold is_derive_vec in IHdf2.
- unfold df_R, df_eval_at_point; simpl.
- generalize (is_derive_vsum
- (fun x : R =>
- match df_eval (addBinding σ v0 x) df2 with
- | Some l' =>
- match
- vectoro_to_ovector
- (fun i : {n' : nat | (n' < n)%nat} =>
- df_eval
- (mk_env_entry (v1, DTfloat) (l' i)
- :: mk_env_entry (v2, DTfloat) (r i)
- :: Datatypes.nil) df1)
- with
- | Some vv => vv
- | None => fun _ => 0%R
- end
- | None => fun _ => 0%R
- end)
- xx v); intros.
- apply (is_derive_ext
- (fun x0 : R_AbsRing =>
- vsum
- match df_eval (addBinding σ v0 x0) df2 with
- | Some l' =>
- match
- vectoro_to_ovector
- (fun i : {n' : nat | (n' < n)%nat} =>
- df_eval
- (mk_env_entry (v1, DTfloat) (l' i)
- :: mk_env_entry (v2, DTfloat) (r i)
- :: Datatypes.nil) df1)
- with
- | Some vv => vv
- | None => fun _ : {n' : nat | (n' < n)%nat} => 0
- end
- | None => fun _ : {n' : nat | (n' < n)%nat} => 0
- end)); intros.
- + destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- rewrite e.
- match_option.
- apply vsum0.
- + apply H2.
- unfold is_derive_vec; simpl; intros.
- generalize
- (@is_derive_comp R_AbsRing R_NormedModule
- (df_R (addBinding nil v2 (r i)) df1 v1)
- (fun x0 => (df_R_vec σ df2 v0 x0 i))); simpl; intros.
- apply (is_derive_ext
- (fun x0 : R =>
- df_R (addBinding Datatypes.nil v2 (r i)) df1 v1 (df_R_vec σ df2 v0 x0 i)));intros.
- * unfold df_R_vec, df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- rewrite e.
- destruct (eval_fully_closed_total
- (addBinding (addBinding Datatypes.nil v2 (r i)) v1 (x i)) df1); simpl; trivial.
- rewrite e0.
- match_option.
- -- specialize (vectoro_to_ovector_forall_some_f eqq2); intros.
- specialize (H4 i); simpl in H4.
- unfold addBinding in e0.
- congruence.
- -- apply vectoro_to_ovector_exists_None in eqq2.
- destruct eqq2.
- destruct (eval_fully_closed_total
- (mk_env_entry (v1, DTfloat) (x x1)
- :: mk_env_entry (v2, DTfloat) (r x1)
- :: Datatypes.nil) df1); trivial.
- congruence.
- * specialize (H0 i).
- match_option_in H0.
- inversion H0.
- apply H3; trivial.
- apply IHdf1; trivial.
- unfold addBinding, df_R_vec, df_eval_at_point; simpl.
- rewrite eqq.
- apply eqq2.
- - Case "MLossfun"%string.
- destruct H.
- do 3 match_option_in H0.
- invcs H0.
- destruct (Nat.eq_dec n 0); [congruence|].
- invcs H3.
- specialize (vectoro_to_ovector_forall_some_f eqq1); simpl; intros.
- specialize (IHdf2 d0 σ v0 xx H1 eqq0).
- simpl in IHdf2; simpl in IHdf1.
- unfold is_derive_mat in IHdf2.
- replace ((@msum floatish_R m n m0) / IZR (Z.of_nat n))%R with (/ (IZR (Z.of_nat n))%R * msum m0)%R by lra.
- apply (is_derive_ext (fun x0 => scal (Rinv (IZR (Z.of_nat n))%R)
- (scal (IZR (Z.of_nat n))%R
- (df_R σ (MLossfun ann v1 v2 df1 df2 r) v0 x0)))); intros.
- + unfold scal; simpl.
- unfold mult; simpl.
- field.
- apply IZR_neq; lia.
- + apply is_derive_scal.
- unfold df_R, df_eval_at_point; simpl.
- generalize (is_derive_msum
- (fun x : R =>
- match df_eval (addBinding σ v0 x) df2 with
- | Some l' =>
- match
- matrixo_to_omatrix
- (fun (i : {n' : nat | (n' < m)%nat})
- (j : {m' : nat | (m' < n)%nat}) =>
- df_eval
- (mk_env_entry (v1, DTfloat) (l' i j)
- :: mk_env_entry (v2, DTfloat) (r i j)
- :: Datatypes.nil) df1)
- with
- | Some vv => vv
- | None => fun _ _ => 0%R
- end
- | None => fun _ _ => 0%R
- end)
- xx m0); intros.
- apply (is_derive_ext
- (fun x0 : R_AbsRing =>
- msum
- match df_eval (addBinding σ v0 x0) df2 with
- | Some l' =>
- match
- matrixo_to_omatrix
- (fun (i : {n' : nat | (n' < m)%nat})
- (j : {m' : nat | (m' < n)%nat}) =>
- df_eval
- (mk_env_entry (v1, DTfloat) (l' i j)
- :: mk_env_entry (v2, DTfloat) (r i j)
- :: Datatypes.nil) df1)
- with
- | Some vv => vv
- | None =>
- fun (_ : {n' : nat | (n' < m)%nat})
- (_ : {m' : nat | (m' < n)%nat}) => 0
- end
- | None => fun (_ : {n' : nat | (n' < m)%nat})
- (_ : {m' : nat | (m' < n)%nat}) => 0
- end)); intros.
- * destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- rewrite e.
- unfold matrixo_to_omatrix; simpl.
- match_option.
- -- unfold scal; simpl.
- unfold mult; simpl.
- generalize (msum v); unfold float; simpl; intros.
- field.
- rewrite <- INR_IZR_INZ.
- now apply INR_nzero_eq.
- -- unfold scal; simpl.
- unfold mult; simpl.
- replace (IZR (Z.of_nat n) * 0)%R with (0)%R by lra.
- unfold msum.
- erewrite vsum_ext; try apply vsum0; intro.
- rewrite vmap_nth.
- erewrite vsum_ext; try apply vsum0; intro.
- trivial.
- * apply H2.
- unfold is_derive_mat; simpl; intros.
- generalize
- (@is_derive_comp R_AbsRing R_NormedModule
- (df_R (addBinding nil v2 (r i j)) df1 v1)
- (fun x0 => (df_R_mat σ df2 v0 x0 i j))); simpl; intros.
- apply (is_derive_ext
- (fun x0 : R =>
- df_R (addBinding Datatypes.nil v2 (r i j)) df1 v1
- (df_R_mat σ df2 v0 x0 i j)));intros.
- -- unfold df_R_mat, df_R, df_eval_at_point; simpl.
- destruct (eval_fully_closed_total (addBinding σ v0 t) df2); simpl; trivial.
- rewrite e.
- destruct (eval_fully_closed_total
- (addBinding (addBinding Datatypes.nil v2 (r i j)) v1 (x i j)) df1)
- ; simpl; trivial.
- rewrite e0.
- match_option.
- ++ specialize (vectoro_to_ovector_forall_some_f eqq2); intros.
- specialize (H4 i); simpl in H4.
- specialize (vectoro_to_ovector_forall_some_f H4); intros.
- specialize (H5 j); simpl in H5.
- unfold addBinding in e0.
- congruence.
- ++ unfold matrixo_to_omatrix in eqq2.
- apply vectoro_to_ovector_exists_None in eqq2.
- destruct eqq2.
- apply vectoro_to_ovector_exists_None in e1.
- destruct e1.
- destruct (eval_fully_closed_total
- (mk_env_entry (v1, DTfloat) (x x1 x2)
- :: mk_env_entry (v2, DTfloat) (r x1 x2)
- :: Datatypes.nil) df1); trivial.
- congruence.
- -- specialize (H0 i).
- specialize (vectoro_to_ovector_forall_some_f H0); intros.
- specialize (H4 j); simpl in H4.
- match_option_in H4.
- inversion H4.
- apply H3; trivial.
- apply IHdf1; trivial.
- unfold addBinding, df_R_vec, df_eval_at_point; simpl.
- unfold df_R_mat, df_eval_at_point; simpl.
- rewrite eqq.
- apply eqq2.
- Qed.
diff --git a/coq/utils/ExtrFloatishIEEE.v b/coq/utils/ExtrFloatishIEEE.v
deleted file mode 100644
index ad4a798a..00000000
--- a/coq/utils/ExtrFloatishIEEE.v
+++ /dev/null
@@ -1,33 +0,0 @@
-Require Import Extraction.
-Require Import FloatishIEEE.
-
-(* This is assumed by fromZ *)
-Require Import ExtrOcamlZInt.
-
-Extract Constant IEEE_float => "float".
-
-Extract Inlined Constant IEEE_zero => "0.".
-Extract Inlined Constant IEEE_opp => "Float.neg".
-Extract Inlined Constant IEEE_plus => "Float.add".
-Extract Inlined Constant IEEE_minus => "Float.sub".
-Extract Inlined Constant IEEE_mult => "Float.mul".
-Extract Inlined Constant IEEE_div => "Float.div".
-Extract Inlined Constant IEEE_sqrt => "Float.sqrt".
-Extract Inlined Constant IEEE_abs => "Float.abs".
-
-
-Extract Inlined Constant IEEE_exp => "Float.exp".
-Extract Inlined Constant IEEE_ln => "Float.log".
-
-Extract Inlined Constant IEEE_pi => "Float.pi".
-Extract Inlined Constant IEEE_sin => "Float.sin".
-Extract Inlined Constant IEEE_cos => "Float.cos".
-
-Extract Inlined Constant IEEE_fromZ => "Float.of_int".
-
-Extract Inlined Constant IEEE_eq => "Float.equal".
-Extract Inlined Constant IEEE_neq => "(fun x y -> x <> y)".
-Extract Inlined Constant IEEE_lt => "(fun x y -> x < y)".
-Extract Inlined Constant IEEE_le => "(fun x y -> x <= y)".
-Extract Inlined Constant IEEE_gt => "(fun x y -> x > y)".
-Extract Inlined Constant IEEE_ge => "(fun x y -> x >= y)".
diff --git a/coq/utils/Floatish.v b/coq/utils/Floatish.v
deleted file mode 100644
index f3ee3f12..00000000
--- a/coq/utils/Floatish.v
+++ /dev/null
@@ -1,8 +0,0 @@
-Require Export FloatishDef.
-Require Export FloatishOps.
-
-Require Export FloatishInterval.
-Require Export FloatishIEEE.
-Require Export FloatishReal.
-
-Require Export FloatishRealOps.
diff --git a/coq/utils/Floatish/FloatishDef.v b/coq/utils/Floatish/FloatishDef.v
deleted file mode 100644
index b244525e..00000000
--- a/coq/utils/Floatish/FloatishDef.v
+++ /dev/null
@@ -1,52 +0,0 @@
-Require Import BinInt.
-
-Declare Scope float.
-
-Class floatish : Type :=
- {
- float : Type
- ; Fzero : float
-
- ; Fopp : float -> float
-
- ; Fplus : float -> float -> float
- ; Fminus : float -> float -> float
- ; Fmult : float -> float -> float
- ; Fdiv : float -> float -> float
-
- ; Fsqrt : float -> float
- ; Fabs : float -> float
-
- ; Fexp : float -> float
- ; Fln : float -> float
-
- ; Fsin : float -> float
- ; Fcos : float -> float
- ; Fpi : float
-
- ; FfromZ : Z -> float
-
- ; Feq : float -> float -> bool
- ; Fneq : float -> float -> bool
- ; Flt : float -> float -> bool
- ; Fle : float -> float -> bool
- ; Fgt : float -> float -> bool
- ; Fge : float -> float -> bool
- }.
-
-Notation "0" := (Fzero) : float.
-Notation "1" := (FfromZ 1) : float.
-Notation "2" := (FfromZ 2) : float.
-Notation "- x" := (Fopp x) (at level 35, right associativity) : float.
-Notation "x + y" := (Fplus x y) (at level 50, left associativity) : float.
-Notation "x - y" := (Fminus x y) (at level 50, left associativity) : float.
-Notation "x * y" := (Fmult x y) (at level 40, left associativity) : float.
-Notation "x / y" := (Fdiv x y) (at level 40, left associativity) : float.
-
-
-Notation "x ==b y" := (Feq x y) (at level 70, no associativity) : float.
-Notation "x != y" := (Fneq x y) (at level 70, no associativity) : float.
-Notation "x < y" := (Flt x y) (at level 70, no associativity) : float.
-Notation "x <= y" := (Fle x y) (at level 70, no associativity) : float.
-Notation "x > y" := (Fgt x y) (at level 70, no associativity) : float.
-Notation "x >= y" := (Fge x y) (at level 70, no associativity) : float.
diff --git a/coq/utils/Floatish/FloatishIEEE.v b/coq/utils/Floatish/FloatishIEEE.v
deleted file mode 100644
index 91a82a50..00000000
--- a/coq/utils/Floatish/FloatishIEEE.v
+++ /dev/null
@@ -1,98 +0,0 @@
-Require Import Flocq.IEEE754.BinarySingleNaN.
-Require Import BinInt.
-
-Require Import FloatishDef.
-
-Require Import Flocq.IEEE754.Binary.
-Require Import Flocq.IEEE754.Bits.
-
-
-Definition IEEE_float := binary64.
-Definition IEEE_zero : IEEE_float := B754_zero 53 1024 false.
-Definition IEEE_opp := b64_opp.
-Definition IEEE_plus := b64_plus mode_NE.
-Definition IEEE_minus := b64_minus mode_NE.
-Definition IEEE_mult := b64_mult mode_NE.
-Definition IEEE_div := b64_div mode_NE.
-Definition IEEE_sqrt := b64_sqrt mode_NE.
-Definition IEEE_abs := b64_abs.
-
-Definition IEEE_fromZ i := binary_normalize 53 1024 (eq_refl _) (eq_refl _) mode_NE i 0 false.
-
-Definition IEEE_eq (x y:IEEE_float)
- := (match b64_compare x y with
- | Some Eq => true
- | _ => false
- end).
-
-Definition IEEE_neq (x y:IEEE_float)
- := (match b64_compare x y with
- | Some Eq => false
- | Some _ => true
- | _ => false
- end).
-
-Definition IEEE_lt (x y:IEEE_float)
- := (match b64_compare x y with
- | Some Lt => true
- | _ => false
- end).
-
-Definition IEEE_le (x y:IEEE_float)
- := (match b64_compare x y with
- | Some Lt => true
- | Some Eq => true
- | _ => false
- end).
-
-
-Definition IEEE_gt (x y:IEEE_float)
- := (match b64_compare x y with
- | Some Gt => true
- | _ => false
- end).
-
-
-Definition IEEE_ge (x y:IEEE_float)
- := (match b64_compare x y with
- | Some Gt => true
- | Some Eq => true
- | _ => false
- end).
-
-(* following function will be defined only via extraction *)
-Axiom IEEE_exp : IEEE_float -> IEEE_float.
-Axiom IEEE_ln : IEEE_float -> IEEE_float.
-Axiom IEEE_pi : IEEE_float.
-Axiom IEEE_sin : IEEE_float -> IEEE_float.
-Axiom IEEE_cos : IEEE_float -> IEEE_float.
-
-Local Instance floatish_IEEE : floatish :=
- {
- float := IEEE_float
- ; Fzero := IEEE_zero
- ; Fopp := IEEE_opp
- ; Fplus := IEEE_plus
- ; Fminus := IEEE_minus
- ; Fmult := IEEE_mult
- ; Fdiv := IEEE_div
- ; Fsqrt := IEEE_sqrt
- ; Fabs := IEEE_abs
-
- ; Fexp := IEEE_exp
- ; Fln := IEEE_ln
-
- ; Fpi := IEEE_pi
- ; Fsin := IEEE_sin
- ; Fcos := IEEE_cos
-
- ; FfromZ := IEEE_fromZ
-
-
- ; Feq := IEEE_eq
- ; Fneq := IEEE_neq
- ; Flt := IEEE_lt
- ; Fle := IEEE_le
- ; Fgt := IEEE_gt
- ; Fge := IEEE_ge
- }.
diff --git a/coq/utils/Floatish/FloatishInterval.v b/coq/utils/Floatish/FloatishInterval.v
deleted file mode 100644
index 0c0516e3..00000000
--- a/coq/utils/Floatish/FloatishInterval.v
+++ /dev/null
@@ -1,82 +0,0 @@
-Require Import BinInt.
-
-Require Import Interval.Real.Xreal.
-
-Require Import Interval.Interval.Transcend.
-Require Import Interval.Float.Specific_ops.
-Require Import Interval.Float.Specific_stdz.
-Require Import Interval.Float.Basic.
-
-Require Import FloatishDef.
-
-Module F := SpecificFloat StdZRadix2.
-Module A := TranscendentalFloatFast F.
-
-Local Instance floatish_interval_gen (prec:Z) : floatish :=
- {
- float := F.type
- ; Fzero := F.zero
-
- ; Fopp := F.neg
-
- ; Fplus := F.add_slow rnd_NE prec
- ; Fminus x y := F.add_slow rnd_NE prec x (F.neg y)
- ; Fmult := F.mul rnd_NE prec
- ; Fdiv := F.div rnd_NE prec
-
- ; Fsqrt := F.sqrt rnd_NE prec
- ; Fabs := F.abs
-
- ; Fexp x := A.I.midpoint(A.exp_fast prec x)
- ; Fln x := A.I.midpoint(A.ln_fast prec x)
-
- ; Fsin x := A.I.midpoint(A.sin_fast prec x)
- ; Fcos x := A.I.midpoint(A.cos_fast prec x)
- ; Fpi := F.mul rnd_NE prec (F.fromZ 4) (A.I.midpoint (A.pi4 prec))
-
- ; FfromZ := F.fromZ
-
- ; Feq (x y:F.type)
- := (match F.cmp x y with
- | Xeq => true
- | _ => false
- end)
-
- ; Fneq (x y:F.type)
- := (match F.cmp x y with
- | Xeq => false
- | _ => true
- end)
-
- ; Flt (x y:F.type)
- := (match F.cmp x y with
- | Xlt => true
- | _ => false
- end)
-
- ; Fle (x y:F.type)
- := (match F.cmp x y with
- | Xlt => true
- | Xeq => true
- | _ => false
- end)
-
- ; Fgt (x y:F.type)
- := (match F.cmp x y with
- | Xgt => true
- | _ => false
- end)
-
- ; Fge (x y:F.type)
- := (match F.cmp x y with
- | Xgt => true
- | Xeq => true
- | _ => false
- end)
- }.
-
-
-Local Instance floatish_interval : floatish := floatish_interval_gen 53.
-
-Definition FZF (r:float) := F.nearbyint rnd_NE r.
-Definition FZFscale (n:Z) (r:float) := FZF (Fmult (FfromZ n) r).
diff --git a/coq/utils/Floatish/FloatishOps.v b/coq/utils/Floatish/FloatishOps.v
deleted file mode 100644
index b01f1a39..00000000
--- a/coq/utils/Floatish/FloatishOps.v
+++ /dev/null
@@ -1,25 +0,0 @@
-Require Import FloatishDef.
-
-Section floatish_ops.
-
- Context {floatish_impl:floatish}.
- Local Open Scope float.
-
- Definition pos_sign (e:float)
- := if e >= 0 then 1 else Fopp 1.
-
- Definition neg_sign (e:float)
- := if e <= 0 then Fopp 1 else 1.
-
- Definition sign (e:float)
- := if e < 0 then Fopp 1
- else if e > 0 then 1
- else 0.
-
- Definition Fmax (x y:float)
- := if x < y then y else x.
-
- Definition Fmin (x y:float)
- := if x > y then y else x.
-
-End floatish_ops.
diff --git a/coq/utils/Floatish/FloatishReal.v b/coq/utils/Floatish/FloatishReal.v
deleted file mode 100644
index b7bfde15..00000000
--- a/coq/utils/Floatish/FloatishReal.v
+++ /dev/null
@@ -1,33 +0,0 @@
-Require Import BinInt Reals Lra.
-
-Require Import FloatishDef.
-
-Local Instance floatish_R : floatish :=
- {
- float := R
- ; Fzero := 0%R
- ; Fopp := Ropp
- ; Fplus := Rplus
- ; Fminus := Rminus
- ; Fmult := Rmult
- ; Fdiv := Rdiv
- ; Fsqrt := sqrt
- ; Fabs := Rabs
-
- ; Fexp := exp
- ; Fln := ln
-
- ; Fpi := PI
- ; Fsin := sin
- ; Fcos := cos
-
- ; FfromZ := IZR
-
-
- ; Feq x y := if Req_EM_T x y then true else false
- ; Fneq x y := if Req_EM_T x y then false else true
- ; Flt x y := if Rlt_dec x y then true else false
- ; Fle x y := if Rle_dec x y then true else false
- ; Fgt x y := if Rgt_dec x y then true else false
- ; Fge x y := if Rge_dec x y then true else false
- }.
diff --git a/coq/utils/Floatish/FloatishRealOps.v b/coq/utils/Floatish/FloatishRealOps.v
deleted file mode 100644
index 3a20978a..00000000
--- a/coq/utils/Floatish/FloatishRealOps.v
+++ /dev/null
@@ -1,21 +0,0 @@
-Require Import BinInt Reals Lra.
-Require Import FloatishDef FloatishReal FloatishOps.
-
-Section real_pfs.
-
- Local Existing Instance floatish_R.
- Lemma Fmax_Rmax x y : Fmax x y = Rmax x y.
- Proof.
- vm_compute.
- destruct (Rlt_dec x y); destruct (Rle_dec); lra.
- Qed.
-
- Lemma Fmin_Rmin x y : Fmin x y = Rmin x y.
- Proof.
- vm_compute.
- destruct (Rgt_dec x y); destruct (Rle_dec); lra.
- Qed.
-
-End real_pfs.
-
-Hint Rewrite Fmax_Rmax Fmin_Rmin : Rarith.
diff --git a/coq/utils/nvector.v b/coq/utils/nvector.v
deleted file mode 100644
index 8e14a0d7..00000000
--- a/coq/utils/nvector.v
+++ /dev/null
@@ -1,282 +0,0 @@
-Require Import List.
-Require Import BinInt.
-Require Import Lia.
-Require Import LibUtils.
-
-Require Import VectorDef.
-Require Vector.
-
-Section Vector.
-
-Definition vector (T:Type) (n:nat) := Vector.t T n.
-
-Definition vnil {T} : vector T 0 := nil T.
-
-Definition vcons {T} (n:nat) (c:T) (v:vector T n) : vector T (S n) :=
- cons T c n v.
-
-Definition vappend {T} (n1 n2:nat) (v1:vector T n1) (v2:vector T n2) : vector T (n1 + n2)
- := append v1 v2.
-
-Definition vmap {A B} {n} (f:A->B) (v : vector A n) : vector B n := map f v.
-
-Definition vhd {T} {n:nat} (v : vector T (S n)):T := hd v.
-
-Definition vtl {T} {n:nat} (v : vector T (S n)) : vector T n := tl v.
-
-Definition vlast {T} {n:nat} (v : vector T (S n)) := last v.
-
-Definition vnth {T} {n:nat} (v : vector T n) (i:nat | i T :=
- fun i => vnth v i.
-
-Program Definition ConstVector {T} (n:nat) (c:T) : vector T n
- := of_list (repeat c n).
-Next Obligation.
- now rewrite repeat_length.
-Qed.
-
-Program Definition build_vector {T} {n:nat} (v:{n':nat | n' < n}%nat -> T) : vector T n
- := of_list (Vector.vector_to_list v).
-Next Obligation.
- apply Vector.vector_to_list_length.
-Qed.
-
-Lemma to_list_length {T} {n:nat} (v : vector T n) : length (to_list v) = n.
- induction v; simpl; trivial.
- now f_equal.
-Qed.
-
-Program Definition vcombine {T1 T2} {n:nat} (v1:vector T1 n) (v2:vector T2 n): vector (T1*T2) n :=
- of_list (combine (to_list v1) (to_list v2)).
-Next Obligation.
- rewrite combine_length.
- rewrite to_list_length, to_list_length.
- apply PeanoNat.Nat.min_id.
-Qed.
-
-Definition vector_zip {T1 T2} {n:nat} (v1:vector T1 n) (v2:vector T2 n): vector (T1*T2) n :=
- vcombine v1 v2.
-
-Definition vmap2 {A B C} {n} (f:A->B->C) (v1 : vector A n) (v2 : vector B n) : vector C n
- := Vector.map2 f v1 v2.
-
-Definition vmap4 {A B} {n} (f:A->A->A->A->B) (v1 v2 v3 v4 : vector A n) : vector B n :=
- vmap2 (fun '(a1,a2) '(a3,a4) => f a1 a2 a3 a4) (vcombine v1 v2) (vcombine v3 v4).
-
-Program Definition vectoro_to_ovector {T} {n} (v:vector (option T) n) : option (vector T n)
- := match listo_to_olist (to_list v) with
- | None => None
- | Some l => Some (of_list l)
- end.
-Next Obligation.
- symmetry in Heq_anonymous.
- apply listo_to_olist_some in Heq_anonymous.
- rewrite <- map_length with (f := Some).
- rewrite <- Heq_anonymous.
- now apply to_list_length.
-Qed.
-
-Definition vforall {A} {m:nat} (P: A -> Prop) (v:vector A m) : Prop
- := Vector.Forall P v.
-
-End Vector.
-
-Section Matrix.
-
-Definition matrix (T:Type) (n m : nat) := vector (vector T m) n.
-
-Definition mat_fun {T:Type} (n m : nat) (mat : matrix T n m ) :
- {n':nat | n' < n}%nat -> {m':nat | m' < m}%nat -> T :=
- fun i => fun j => vnth (vnth mat i) j.
-
-Definition mmap {A B} {n m} (f:A->B) (mat : matrix A n m) : matrix B n m :=
- vmap (vmap f) mat.
-
-Definition mnth {T} {n m :nat} (v : matrix T n m) (i:nat | i vcombine a b) (vcombine mat1 mat2).
-
-Definition matrix_zip {T1 T2} {n m : nat} (mat1 : matrix T1 n m) (mat2 : matrix T2 n m) : matrix (T1*T2) n m := mcombine mat1 mat2.
-
-Definition build_matrix {T} {n m:nat}
- (mat:{n':nat | n' < n}%nat -> {m':nat | m' < m}%nat -> T) : matrix T n m
- := vmap build_vector (build_vector mat).
-
-Definition transpose {T} {m n : nat} (mat:matrix T m n) : matrix T n m
- := build_matrix (fun i j => mnth mat j i).
-
-Definition ConstMatrix {T} (n m : nat) (c:T) : matrix T n m :=
- ConstVector n (ConstVector m c).
-
-Definition matrixo_to_omatrix {T} {m n} (v:matrix (option T) m n) : option (matrix T m n)
- := vectoro_to_ovector (vmap vectoro_to_ovector v).
-
-Definition mmap2 {A B C} {n m} (f:A->B->C) (v1 : matrix A n m) (v2 : matrix B n m) : matrix C n m := vmap2 (fun r1 r2 => vmap2 f r1 r2) v1 v2.
-
-Definition mforall {A} {m n:nat} (P: A -> Prop) (m:matrix A m n) : Prop
- := vforall (fun x => vforall P x) m.
-
-End Matrix.
-
-Section Tensor.
-Fixpoint tensor T (l:list nat) : Type
- := match l with
- | List.nil => T
- | x::l' => vector (tensor T l') x
- end.
-
-Lemma tensor0 T : tensor T List.nil = T.
-Proof.
- reflexivity.
-Qed.
-
-Lemma tensor1 T n : tensor T (n::List.nil) = vector T n.
-Proof.
- reflexivity.
-Qed.
-
-Lemma tensor_app T l1 l2 : tensor (tensor T l1) l2 = tensor T (l2++l1).
-Proof.
- revert l1.
- induction l2; intros l1; simpl; trivial.
- now rewrite IHl2.
-Qed.
-
-Fixpoint ConstTensor {T} (l : list nat) (c:T) : (tensor T l) :=
- match l with
- | List.nil => c
- | x::l' => ConstVector x (ConstTensor l' c)
- end.
-
-Fixpoint Tensor_map {A B} {dims:list nat} (f:A->B) : tensor A dims -> tensor B dims
- := match dims with
- | List.nil => fun x => f x
- | x::l' => vmap (Tensor_map f)
- end.
-
-Definition scalar {T} (c:T) : tensor T List.nil := c.
-
-End Tensor.
-
-Inductive NumericType
- := FloatType
- | IntType.
-
-
-Definition ntype_interp (n:NumericType) : Type
- := match n with
- | FloatType => nat
- | IntType => Z
- end.
-
- Structure BigArray (ln : list nat) (T : NumericType) : Type :=
- Tensor { tdata :> list (ntype_interp T); _ : length tdata = List.fold_right Nat.mul 1%nat ln }.
-
- Structure Array1 (n : nat) (T : NumericType) : Type :=
- array1 { a1data :> list (ntype_interp T); _ : length a1data = n}.
-
- Structure Array2 (n m : nat) (T : NumericType) : Type :=
- array2 { a2data :> list (ntype_interp T); _ : length a2data = n * m}.
-
-Definition tensor_abs_type (T:NumericType) (dims:list nat) := tensor (ntype_interp T) dims.
-
-Class TensorDef :=
- {
- tensor_t (T:NumericType) (dims:list nat) : Type
- ; tensor_repr {T:NumericType} {dims:list nat} : tensor_t T dims -> tensor_abs_type T dims -> Prop
-
- ; tensor_const {T} (dims : list nat) (c:ntype_interp T) : tensor_t T dims
- ; tensor_const_p {T} (dims : list nat) (c:ntype_interp T) : tensor_repr (tensor_const dims c) (ConstTensor dims c)
-
- ; tensor_map {A B} {dims : list nat} (f:ntype_interp A-> ntype_interp B) (t:tensor_t A dims) : tensor_t B dims
- ; tensor_map_p {A B} {dims : list nat} (f:ntype_interp A-> ntype_interp B) (t:tensor_t A dims) :
- forall r, tensor_repr t r ->
- tensor_repr (tensor_map f t) (Tensor_map f r)
-
- (* ; tensor_nth {A} {dims : list nat} (indices:list nat) (indices_in_range:True) (t:tensor A dims) : A *)
- (* ; tensor_nth_p {A:Type} {dims : list nat} (indices:list nat) (indices_in_range:True) (t:tensor_t A dims) : *)
- (* forall r, tensor_repr t r -> *)
- (* tensor_repr (tensor_nth indices indices_in_range t) (tensor_nth indices indices_in_range r) *)
- }.
-
-(*
-Class TensorDefExt {base:TensorDef} :=
- {
- tensor_transpose;
- }.
-*)
- (* ; tensor_nth {A} {dims : list nat} (indices:list nat) (indices_in_range:True) (t:tensor A dims) : A *)
- (* ; tensor_nth_p {A:Type} {dims : list nat} (indices:list nat) (indices_in_range:True) (t:tensor_t A dims) : *)
- (* forall r, tensor_repr t r -> *)
- (* tensor_repr (tensor_nth indices indices_in_range t) (tensor_nth indices indices_in_range r) *)
-
-(* Instance trivial_TensorDef : TensorDef := *)
-(* { *)
-(* tensor_t := tensor; *)
-(* tensor_repr _ _ a b := a = b *)
-(* }. *)
-(*
-Fixpoint flat_list_represent_tensor {T} {dims} (l:list A) (t:tensor T dims) : Prop
- :=
-
-Instance BigArray_TensorDef : TensorDef
- := {
- tensor_t A dims := list A;
- tensor_repr T dims (l:list A) (tensor T dims)
- := fix
- }.
-*)
-
-Require Import Floatish.
-Section float_ops.
- Context {floatish_impl:floatish}.
- Local Open Scope float.
-
- Definition vsum {m:nat} (v:vector float m) : float
- := List.fold_right Fplus 0 (to_list v).
-
- Definition msum {m n:nat} (v:matrix float m n) : float :=
- vsum (vmap vsum v).
-
- Definition vdot {m:nat} (v1 v2 : vector float m) : float :=
- List.fold_right Fplus 0
- (List.map (fun '(a,b) => a * b)
- (combine (to_list v1) (to_list v2))).
-
- Definition vadd {m:nat} (v1 v2 : vector float m) :=
- vmap (fun '(a,b) => a+b) (vcombine v1 v2).
-
- Definition madd {m n:nat} (mat1 mat2 : matrix float m n) :=
- mmap (fun '(a,b) => a+b) (mcombine mat1 mat2).
-
- Definition matrix_vector_mult {n m} (l : matrix float n m)(r : vector float m) : vector float n :=
- vmap (fun l1 => vdot l1 r) l.
-
- Definition matrix_vector_add {n m} (l : matrix float n m) (r : vector float n) : matrix float n m :=
- build_matrix (fun i j => (vnth (vnth l i) j) + (vnth r i)).
-
-(*
- transpose (vmap (fun l1 => vadd l1 r) (transpose l)).
- *)
-
- Definition matrix_mult {n m p} (l : matrix float n m)(r : matrix float m p) : matrix float n p :=
- build_matrix (fun i k => vsum (build_vector
- (fun j => (vnth (vnth l i) j) *
- (vnth (vnth r j) k)))).
-
-(*
- transpose (vmap (fun r1 => matrix_vector_mult l r1) (transpose r)).
-*)
-
-End float_ops.
-
-
-
-
-
-
diff --git a/ocaml/Makefile b/ocaml/Makefile
deleted file mode 100644
index 18179501..00000000
--- a/ocaml/Makefile
+++ /dev/null
@@ -1,78 +0,0 @@
-#
-# Copyright 2015-2016 IBM Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# User-level configuration
-# include ../Makefile.config
-# Contains the list of all the Coq modules
-include ../Makefile.coq_modules
-
-## Configuraton
-NNOPT_HOME=$(CURDIR)/..
-
-############# Shouldn't have to be changed after this
-OCAMLBUILD= ocamlbuild \
- -no-links -classic-display \
- -tags annot -use-ocamlfind -package unix -package base64 -package csv
-
-MENHIRFLAG=-use-menhir
-#MENHIRFLAG=
-
-## Mains
-MAIN=nnopt
-
-TARGET=native
-
-## Toplevel
-all: ../bin/$(MAIN)
-
-native: ../bin/$(MAIN)
-
-## Extraction
-VO_FILES = $(MODULES:%=../coq/%.vo)
-
-extracted: extracted/StaticConfig.ml extracted/NnoptExtracted.ml extracted/NnoptExtracted.mli
-
-extracted/StaticConfig.ml extracted/NnoptExtracted.ml extracted/NnoptExtracted.mli : $(VO_FILES) NnoptExtraction.v
- rm -rf extracted
- mkdir -p extracted
- echo "(* This file is generated *)" > extracted/StaticConfig.ml
- echo "let nnopt_home = \"$(NNOPT_HOME)\"" >> extracted/StaticConfig.ml
- (coqc -R ../coq FormalML NnoptExtraction.v)
-
-## Native
-../bin/$(MAIN): extracted $(MAIN).$(TARGET) ../bin
- cp _build/$(MAIN).$(TARGET) ../bin/$(MAIN)
-
-../bin:
- @mkdir -p ../bin
-
-$(MAIN).$(TARGET): extracted
- $(OCAMLBUILD) $(MENHIRFLAG) -Is extracted -Is src $(MAIN).$(TARGET)
-
-## Clean
-
-clean:
- ocamlbuild -clean -no-log
- rm -rf _build
- rm -f ../bin/$(MAIN)
-
-cleanall: clean
- rm -f NnoptExtraction.glob NnoptExtraction.vo .NnoptExtraction.aux
- rm -rf extracted
- rm -rf *~
-
-.NOTPARALLEL:
-
diff --git a/ocaml/NnoptExtraction.v b/ocaml/NnoptExtraction.v
deleted file mode 100644
index 08a3073b..00000000
--- a/ocaml/NnoptExtraction.v
+++ /dev/null
@@ -1,24 +0,0 @@
-(* Configuration of the extraction *)
-
-Require Extraction.
-Extraction Language OCaml.
-Require Import ExtrOcamlBasic ExtrOcamlString ExtrOcamlZInt ExtrOcamlNatInt.
-Extraction Blacklist String List.
-
-Require Import FloatishIEEE.
-Require Import ExtrFloatishIEEE.
-
-(* Require Import ExtrR. *)
-(* Our stuff modules *)
-
-Require API.
-
-(* Workaround for https://github.com/coq/coq/issues/13288 , suggested by a comment on the issue.
- Coq extraction currently creates a clash between the extracted Decimal.int and the
- ocaml int type.
-*)
-Extract Inductive Decimal.int => unit [ "(fun _ -> ())" "(fun _ -> ())" ] "(fun _ _ _ -> assert false)".
-
-Cd "./extracted".
-
-Recursive Extraction Library API.
diff --git a/ocaml/nnopt.ml b/ocaml/nnopt.ml
deleted file mode 100644
index e2f1cbd6..00000000
--- a/ocaml/nnopt.ml
+++ /dev/null
@@ -1,74 +0,0 @@
-open Format
-
-open API
-open Util
-open Pretty
-
-
-let () = Format.printf "Result of running opt: %a\n" (pretty_visible_option pretty_df_env) API.opt ;;
-let () = Format.printf "Result of running opt2: %a\n" (pretty_visible_option pretty_df_env) API.opt2 ;;
-let () = Format.printf "The testopt environment: %a\n" pretty_df_env API.testopt ;;
-let () = Format.printf "The testreeopt environment: %a\n" pretty_df_env API.testreeopt ;;
-
-let () = Format.printf "The gradenv environment: %a\n" pretty_df_env API.gradenv ;;
-let () = Format.printf "The gradenv_tree environment: %a\n" pretty_df_env API.gradenv_tree ;;
-
-let () = Format.printf "The test_update environment: %a\n" pretty_df_env API.test_update ;;
-
-let () = Format.printf "The test environment: %a\n" pretty_df_env API.test_env ;;
-
-let data = read_int_matrix_from_csv "breast-cancer-wisconsin.data" ;;
-let actual_data = API.discard_first data ;;
-
-let () = Format.printf "first part of data without the first column: %d\n" (List.hd (List.hd actual_data))
-let normalized_data = API.normalizeIntData actual_data ;;
-let () = Format.printf "first 10 rows of normalized data without the first column: \n%a\n" ( pretty_matrix 10 10) normalized_data ;;
-
-let () = Random.self_init()
-
-let randomStream = mkIndexedStream 0 (Obj.magic (API.random_float_vector ())) ;;
-let fvals = fst(streamtake 5 randomStream) ;;
-let () = Format.printf "random list : %a\n" (pretty_blist pp_print_float) fvals ;;
-
-let init_env = init_env2 9 15 1 (char_list_of_string "w") (char_list_of_string "b")
- (Obj.magic (random_float_matrix ())) (Obj.magic (random_float_matrix ())) ;;
-let () = Format.printf "Init environment: %a\n" pretty_df_env init_env ;;
-
-let wval = eval_wisconsin_batch 10 (Obj.magic init_env) (Obj.magic normalized_data) ;;
-let () = Format.printf "wisconsin init loss value : %a\n" (pretty_blist pp_print_float) (Obj.magic wval) ;;
-
-let wval2 = wisconsin_test 10 100 (Obj.magic init_env) (Obj.magic normalized_data) ;;
-let () = Format.printf "wisconsin loss value : %a\n" (pretty_blist pp_print_float) (Obj.magic wval2) ;;
-(*
-let wenv = wisconsin_test_env 6 10 (Obj.magic init_env) (Obj.magic normalized_data) ;;
-let () = Format.printf "wisconsin test env: %a\n" pretty_df_env wenv ;;
-
-let nnval = nn_test_val ;;
-let () = Format.printf "nn test init loss value : %a\n" (pretty_blist pp_print_float) (Obj.magic nnval) ;;
-
-let wval3 = nn_test 1 ;;
-let () = Format.printf "NN test loss value : %a\n" (pretty_blist pp_print_float) (Obj.magic wval3) ;;
-
-let wenv3 = nn_test_env 1 ;;
-let () = Format.printf "NN test env: %a\n" pretty_df_env wenv3 ;;
-
-let wenv4 = nn_test_gradenv ;;
-let () = Format.printf "NN gradenv env: %a\n" pretty_df_env wenv4 ;;
-
-let wenv5 = nn_test_gradenv_tree ;;
-let () = Format.printf "NN gradenv env tree: %a\n" pretty_df_env wenv5 ;;
-*)
-(*
-let gradenvtree = wisconsin_gradenv_tree 1 (Obj.magic init_env) (Obj.magic normalized_data) ;;
-let () = Format.printf "wisconsin gradenv_tree : %a\n" pretty_df_env gradenvtree ;;
-
-let gradenv = wisconsin_gradenv 1 (Obj.magic init_env) (Obj.magic normalized_data) ;;
-let () = Format.printf "wisconsin gradenv : %a\n" pretty_df_env gradenv ;;
-*)
-
-
-
-
-
-
-
diff --git a/ocaml/src/Pretty.ml b/ocaml/src/Pretty.ml
deleted file mode 100644
index 8259be20..00000000
--- a/ocaml/src/Pretty.ml
+++ /dev/null
@@ -1,70 +0,0 @@
-open Format
-
-open API
-
-let rec subVar_to_list sv =
- begin match sv with
- | Name s -> (Util.string_of_char_list s, [])
- | Sub (v,i) -> let (s,r) = subVar_to_list v in
- (s, i::r)
- end
-
-let pretty_const_string s ff _ = pp_print_string ff s
-
-let pretty_blist ?(bstart="[") ?(bend="]") ?(bsep=",") pp ff l =
- pp_print_string ff bstart ;
- (pp_print_list ~pp_sep:(pretty_const_string bsep)) pp ff l ;
- pp_print_string ff bend
-
-let pretty_subVar ff sv =
- let (s,l) = subVar_to_list sv in
- pp_print_string ff s ;
- if l <> []
- then pretty_blist pp_print_int ff l
-
-let pretty_definition_function_types ff dft =
- begin match dft with
- | DTfloat -> fprintf ff "%s" "float"
- | DTVector m -> fprintf ff "%s[%d]" "float" m
- | DTMatrix (m,n) -> fprintf ff "%s[%d,%d]" "float" m n
- end
-
-let pretty_var_type ff (sv, dft) =
- fprintf ff "%a{%a}" pretty_subVar sv pretty_definition_function_types dft
-
-let pretty_vector n ff v =
- let fs = List.init n (fun i -> Obj.magic (v i)) in
- pretty_blist pp_print_float ff fs
-
-let pretty_matrix m n ff v =
- let fs = List.init m (fun i -> List.init n (fun j -> Obj.magic (v i j))) in
- pretty_blist (pretty_blist pp_print_float) ff fs
-
-let pretty_definition_function_types_interp ff dft value =
- begin match dft with
- | DTfloat -> pp_print_float ff (Obj.magic value)
- | DTVector m -> pretty_vector m ff (Obj.magic value)
- | DTMatrix (m,n) -> pretty_matrix m n ff (Obj.magic value)
- end
-
-let pretty_env_entry_type ff (ExistT ((sv, dft), value)) =
- pretty_subVar ff sv ;
- pp_print_string ff "->" ;
- pretty_definition_function_types_interp ff dft value
-
-let pretty_df_env ff env =
- pretty_blist ~bstart:"{" ~bend:"}" pretty_env_entry_type ff env
-
-(* This should be replaced with Format.pp_print_option from ocaml >=4.08 *)
-let pretty_option ?none some formatter value =
- begin match value with
- | None ->
- begin
- match none with
- | None -> ()
- | Some none -> none formatter ()
- end
- | Some value -> some formatter value
- end
-
-let pretty_visible_option some formatter value = pretty_option ~none:(pretty_const_string "None") some formatter value
diff --git a/ocaml/src/Pretty.mli b/ocaml/src/Pretty.mli
deleted file mode 100644
index 2dc36726..00000000
--- a/ocaml/src/Pretty.mli
+++ /dev/null
@@ -1,19 +0,0 @@
-open API
-open DefinedFunctions
-open Vector
-
-val pretty_subVar : Format.formatter -> coq_SubVar -> unit
-
-val pretty_definition_function_types : Format.formatter -> definition_function_types -> unit
-
-val pretty_vector : int -> Format.formatter -> float coq_Vector -> unit
-val pretty_matrix : int -> int -> Format.formatter -> float coq_Matrix -> unit
-
-val pretty_var_type : Format.formatter -> var_type -> unit
-
-val pretty_env_entry_type : Format.formatter -> env_entry_type -> unit
-val pretty_df_env : Format.formatter -> df_env -> unit
-
-val pretty_visible_option : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a option -> unit
-
-val pretty_blist : ?bstart:string -> ?bend:string -> ?bsep:string -> (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a list -> unit
diff --git a/ocaml/src/Util.ml b/ocaml/src/Util.ml
deleted file mode 100644
index 06252aa5..00000000
--- a/ocaml/src/Util.ml
+++ /dev/null
@@ -1,28 +0,0 @@
-let string_of_char_list l =
- let b = Bytes.create (List.length l) in
- let i = ref 0 in
- List.iter (fun c -> Bytes.set b !i c; incr i) l;
- Bytes.to_string b
-
-let char_list_of_string s =
- let l = ref [] in
- String.iter (fun c -> l := c :: !l) s;
- List.rev !l
-
-let read_int_matrix_from_csv name =
- let sdata = Csv.load name in
- List.map (List.map int_of_string) sdata
-
-let rec memoized_vector f =
- let cache = Hashtbl.create 10 in
- begin fun n ->
- try Hashtbl.find cache n
- with Not_found -> begin
- let x = f n in
- Hashtbl.add cache n x; x
- end
- end
-
-let random_float_vector () = memoized_vector (fun _ -> Random.float 1.0)
-let random_float_matrix () = memoized_vector (fun _ -> random_float_vector ())
-
diff --git a/ocaml/src/Util.mli b/ocaml/src/Util.mli
deleted file mode 100644
index 6d31874f..00000000
--- a/ocaml/src/Util.mli
+++ /dev/null
@@ -1,9 +0,0 @@
-val string_of_char_list : char list -> string
-val char_list_of_string : string -> char list
-
-val read_int_matrix_from_csv : string -> int list list
-
-val memoized_vector : (int -> 'a) -> int -> 'a
-
-val random_float_vector : unit -> int -> float
-val random_float_matrix : unit -> int -> int -> float
diff --git a/coq/CertRL/LM/README.md b/rocq/CertRL/LM/README.md
similarity index 100%
rename from coq/CertRL/LM/README.md
rename to rocq/CertRL/LM/README.md
diff --git a/coq/CertRL/LM/R_compl.v b/rocq/CertRL/LM/R_compl.v
similarity index 98%
rename from coq/CertRL/LM/R_compl.v
rename to rocq/CertRL/LM/R_compl.v
index 04d53c24..a60be328 100644
--- a/coq/CertRL/LM/R_compl.v
+++ b/rocq/CertRL/LM/R_compl.v
@@ -116,9 +116,9 @@ generalize (archimed (ln d / ln k)); intros (Y1,_).
rewrite Rmult_comm.
apply Rlt_le_trans with (1:=Y1).
generalize (up (ln d / ln k)); clear; intros x.
-rewrite INR_IZR_INZ, Zabs2Nat.id_abs.
+rewrite INR_IZR_INZ, Zabs.inj_Zabs_nat.
apply IZR_le.
-case (Zabs_spec x); intros (T1,T2); rewrite T2; auto with zarith.
+case (Zabs.Zabs_spec x); intros (T1,T2); rewrite T2; auto with zarith.
(* k = 0 *)
exists 1%nat.
rewrite <- Hk';simpl.
diff --git a/coq/CertRL/LM/check_sub_structure.v b/rocq/CertRL/LM/check_sub_structure.v
similarity index 87%
rename from coq/CertRL/LM/check_sub_structure.v
rename to rocq/CertRL/LM/check_sub_structure.v
index d3c2480b..b4a3ea6e 100644
--- a/coq/CertRL/LM/check_sub_structure.v
+++ b/rocq/CertRL/LM/check_sub_structure.v
@@ -29,7 +29,7 @@ Context {CCP: compatible_g P}.
Record Sg:= mk_Sg {
val :> G ;
H: P val
-}.
+ }.
Lemma Sg_eq: forall (x y:Sg), (val x = val y) -> x = y.
Proof.
@@ -45,11 +45,11 @@ Qed.
Definition Sg_zero : Sg := mk_Sg zero (compatible_g_zero P CCP).
Definition Sg_plus (x y : Sg) : Sg :=
- mk_Sg (plus x y)
+ mk_Sg (plus (val x) (val y))
(compatible_g_plus P (val x) (val y) CCP (H x) (H y)).
Definition Sg_opp (x : Sg) : Sg :=
- mk_Sg (opp x)
+ mk_Sg (opp (val x))
(compatible_g_opp P (val x) CCP (H x)).
Lemma Sg_plus_comm: forall (x y:Sg), Sg_plus x y = Sg_plus y x.
@@ -81,12 +81,18 @@ unfold Sg_plus; simpl.
apply plus_opp_r.
Qed.
+Definition Sg_AbelianMonoid_mixin :=
+ AbelianMonoid.Mixin Sg Sg_plus Sg_zero Sg_plus_comm
+ Sg_plus_assoc Sg_plus_zero_r.
+
+Canonical Sg_AbelianMonoid :=
+ AbelianMonoid.Pack Sg (Sg_AbelianMonoid_mixin) Sg.
+
Definition Sg_AbelianGroup_mixin :=
- AbelianGroup.Mixin Sg Sg_plus Sg_opp Sg_zero Sg_plus_comm
- Sg_plus_assoc Sg_plus_zero_r Sg_plus_opp_r.
+ AbelianGroup.Mixin Sg_AbelianMonoid Sg_opp Sg_plus_opp_r.
Canonical Sg_AbelianGroup :=
- AbelianGroup.Pack Sg (Sg_AbelianGroup_mixin) Sg.
+ AbelianGroup.Pack Sg (AbelianGroup.Class _ (Sg_AbelianMonoid_mixin) Sg_AbelianGroup_mixin) Sg.
End Subgroups.
@@ -135,12 +141,18 @@ unfold Sg_plus; unfold Sg_scal; simpl.
apply scal_distr_r.
Qed.
+Definition Sg_MAbelianMonoid_mixin :=
+ AbelianMonoid.Mixin (@Sg _ P) Sg_Mplus Sg_zero Sg_plus_comm
+ Sg_plus_assoc Sg_plus_zero_r.
+
+Canonical Sg_MAbelianMonoid :=
+ AbelianMonoid.Pack(@Sg _ P) (Sg_MAbelianMonoid_mixin) (@Sg _ P).
+
Definition Sg_MAbelianGroup_mixin :=
- AbelianGroup.Mixin Sg Sg_Mplus Sg_opp Sg_zero Sg_plus_comm
- Sg_plus_assoc Sg_plus_zero_r Sg_plus_opp_r.
+ AbelianGroup.Mixin Sg_MAbelianMonoid Sg_opp Sg_plus_opp_r.
Canonical Sg_MAbelianGroup :=
- AbelianGroup.Pack Sg (Sg_MAbelianGroup_mixin) (@Sg _ P).
+ AbelianGroup.Pack (@Sg _ P) (AbelianGroup.Class _ (Sg_MAbelianMonoid_mixin) Sg_MAbelianGroup_mixin) (@Sg _ P).
Definition Sg_ModuleSpace_mixin :=
ModuleSpace.Mixin R_Ring (Sg_MAbelianGroup)
diff --git a/coq/CertRL/LM/compatible.v b/rocq/CertRL/LM/compatible.v
similarity index 82%
rename from coq/CertRL/LM/compatible.v
rename to rocq/CertRL/LM/compatible.v
index 40361961..573926d6 100644
--- a/coq/CertRL/LM/compatible.v
+++ b/rocq/CertRL/LM/compatible.v
@@ -133,10 +133,11 @@ intros. split.
unfold compatible_m in *.
unfold compatible_g in *.
assert (u = opp u').
+replace (@zero (ModuleSpace.AbelianMonoid R_Ring E)) with (@zero (AbelianGroup.AbelianMonoid (ModuleSpace.AbelianGroup R_Ring E))) in H2 by reflexivity.
rewrite <- plus_opp_r with u' in H2.
rewrite plus_comm in H2.
-apply plus_reg_l in H2.
-trivial.
+eapply plus_reg_l.
+eapply H2.
assert (phi' u).
rewrite H3 in H2.
rewrite H3.
@@ -144,8 +145,9 @@ rewrite <- scal_opp_one.
apply (proj2 Cphi'). trivial.
apply H; trivial.
assert (u' = opp u).
+replace (@zero (ModuleSpace.AbelianMonoid R_Ring E)) with (@zero (AbelianGroup.AbelianMonoid (ModuleSpace.AbelianGroup R_Ring E))) in H2 by reflexivity.
rewrite <- plus_opp_r with u in H2.
-apply plus_reg_l in H2. trivial.
+eapply plus_reg_l; eauto.
assert (phi u').
rewrite H3 in H2.
rewrite H3.
@@ -156,7 +158,7 @@ Qed.
Lemma plus_u_opp_v : forall (u v : E), u = v <-> (plus u (opp v) = zero).
intros; split.
-+ intros. rewrite H. rewrite plus_opp_r. reflexivity.
++ intros. rewrite H. now generalize (plus_opp_r v).
+ intros. apply plus_reg_r with (opp v). rewrite plus_opp_r; trivial.
Qed.
@@ -179,7 +181,9 @@ destruct Cphi as ((Cphi1,(x,Cphi2)),Cphi3).
destruct Cphi' as ((Cphi'1,(x',Cphi'2)),Cphi'3).
assert (plus (plus u (opp v)) (plus u' (opp v')) = zero).
rewrite plus_assoc_gen. rewrite H4.
-rewrite plus_assoc_gen. rewrite plus_opp_r. rewrite plus_opp_r.
+rewrite plus_assoc_gen.
+replace (plus v' (opp v')) with (@zero (ModuleSpace.AbelianMonoid R_Ring E)) by (symmetry; generalize (plus_opp_r v'); intros HH; apply HH).
+replace (plus v (opp v)) with (@zero (ModuleSpace.AbelianMonoid R_Ring E)) by (symmetry; generalize (plus_opp_r v); intros HH; apply HH).
rewrite plus_zero_r. reflexivity.
rewrite plus_u_opp_v.
rewrite (plus_u_opp_v u' v').
@@ -203,9 +207,18 @@ apply (Cphi'3 x (opp one)) in H1.
assert ((x = zero) /\ (opp x = zero)).
apply H. trivial. rewrite <- (scal_zero_l y). trivial.
rewrite <- scal_opp_one. trivial.
+simpl.
+replace (@zero
+ (AbelianMonoid.Pack (ModuleSpace.sort R_Ring E)
+ (AbelianGroup.base (ModuleSpace.sort R_Ring E)
+ (ModuleSpace.base R_Ring
+ match E return Type with
+ | @ModuleSpace.Pack _ T _ _ => T
+ end (ModuleSpace.class R_Ring E)))
+ (ModuleSpace.sort R_Ring E))) with (@zero (ModuleSpace.AbelianMonoid R_Ring E)) by reflexivity.
rewrite <- (scal_zero_l y'). trivial.
-rewrite plus_opp_r.
-rewrite plus_zero_l. reflexivity.
+rewrite plus_zero_l.
+apply (plus_opp_r x).
intuition.
Qed.
diff --git a/coq/CertRL/LM/continuous_linear_map.v b/rocq/CertRL/LM/continuous_linear_map.v
similarity index 99%
rename from coq/CertRL/LM/continuous_linear_map.v
rename to rocq/CertRL/LM/continuous_linear_map.v
index cd56c54b..a25ec153 100644
--- a/coq/CertRL/LM/continuous_linear_map.v
+++ b/rocq/CertRL/LM/continuous_linear_map.v
@@ -56,7 +56,7 @@ specialize (H2 (norm x)).
assert (norm x = 0).
case (Req_dec (norm x) 0); trivial.
intros H3.
-elimtype False.
+exfalso.
apply H2.
exists x.
split; trivial.
@@ -93,7 +93,7 @@ Proof.
intros A H1 H2 H3.
case H2; trivial.
intros H4.
-elimtype False.
+exfalso.
specialize (Is_only_zero_set_correct3 H1).
intros H5.
assert (~(exists x : E, x <> zero)).
@@ -104,7 +104,7 @@ intros x.
case (Req_dec (norm x) 0).
apply norm_eq_zero.
intros H6.
-elimtype False.
+exfalso.
apply H.
exists x.
intros H7; apply H6.
@@ -236,7 +236,6 @@ Proof.
intros f.
generalize (operator_norm_ge_0 f).
destruct (operator_norm f); try easy.
-intros _; apply Rle_refl; simpl.
Qed.
Lemma operator_norm_helper:
@@ -889,12 +888,19 @@ unfold clm_plus, clm_opp; simpl.
apply: plus_opp_r.
Qed.
+
+Definition clm_AbelianMonoid_mixin :=
+ AbelianMonoid.Mixin clm clm_plus clm_zero clm_plus_comm
+ clm_plus_assoc clm_plus_zero_r.
+
+Canonical clm_AbelianMonoid :=
+ AbelianMonoid.Pack clm (clm_AbelianMonoid_mixin) clm.
+
Definition clm_AbelianGroup_mixin :=
- AbelianGroup.Mixin clm clm_plus clm_opp clm_zero clm_plus_comm
- clm_plus_assoc clm_plus_zero_r clm_plus_opp_r.
+ AbelianGroup.Mixin clm_AbelianMonoid clm_opp clm_plus_opp_r.
Canonical clm_AbelianGroup :=
- AbelianGroup.Pack clm (clm_AbelianGroup_mixin) clm.
+ AbelianGroup.Pack clm (AbelianGroup.Class _ (clm_AbelianMonoid_mixin) clm_AbelianGroup_mixin) clm.
(** Clm is ModuleSpace *)
@@ -1308,7 +1314,7 @@ specialize (H2 (norm x)).
assert (norm x = 0).
case (Req_dec (norm x) 0); trivial.
intros H3.
-elimtype False.
+exfalso.
apply H2.
exists x.
split; trivial.
@@ -1391,7 +1397,7 @@ intros A H1 H2 H3.
case H2; trivial.
intros H4.
apply H3.
-elimtype False.
+exfalso.
apply H1.
apply Is_only_zero_set_correct2''_phi.
intros x.
@@ -1460,7 +1466,6 @@ Proof.
intros f.
generalize (operator_norm_ge_0_phi f).
destruct (operator_norm_phi f); try easy.
-intros _; apply Rle_refl; simpl.
Qed.
Lemma operator_norm_helper'_phi:
diff --git a/coq/CertRL/LM/fixed_point.v b/rocq/CertRL/LM/fixed_point.v
similarity index 98%
rename from coq/CertRL/LM/fixed_point.v
rename to rocq/CertRL/LM/fixed_point.v
index 7f4bd668..228c7d51 100644
--- a/coq/CertRL/LM/fixed_point.v
+++ b/rocq/CertRL/LM/fixed_point.v
@@ -105,7 +105,7 @@ Proof.
intros f a k p m D H1 H2 H3 H4.
case_eq m.
(* *)
-intros _; rewrite plus_0_r.
+intros _; rewrite Nat.add_0_r.
assert (L:(0 < k ^ p / (1 - k) * D)).
apply Rmult_lt_0_compat; trivial.
apply Rdiv_lt_0_compat.
@@ -184,7 +184,8 @@ Proof.
intros a k p m n D (H1,H1') H2 Phi_a H3 H4 Hp Hm Hn.
case H1; intros P.
(* *)
-case (le_or_lt p m); intros H5.
+
+case (Nat.le_gt_cases p m); intros H5.
(* . *)
replace m with (p+(m-p))%nat.
apply ball_le with (k^p/(1-k) *D).
@@ -197,7 +198,7 @@ apply Rle_pow_le; try assumption.
split; try left; try assumption.
apply dist_iter; try assumption.
split; assumption.
-now apply le_plus_minus_r.
+auto with arith.
(* . *)
apply ball_sym.
replace p with (m+(p-m))%nat.
@@ -211,7 +212,7 @@ apply Rle_pow_le; try assumption.
split; try left; assumption.
apply dist_iter; try assumption.
split; assumption.
-now apply le_plus_minus_r, lt_le_weak.
+auto with arith.
(* *)
apply ball_le with 0.
rewrite <- P.
@@ -219,8 +220,8 @@ rewrite pow_i; try assumption.
right; unfold Rdiv; ring.
apply dist_iter_aux_zero; try assumption.
now rewrite P.
-now apply lt_le_trans with n.
-now apply lt_le_trans with n.
+now apply Nat.lt_le_trans with n.
+now apply Nat.lt_le_trans with n.
Qed.
End iter_dist.
@@ -403,11 +404,11 @@ intros P Q (N1,H1) (N2,H2).
exists (max N1 N2).
intros n Hn; split.
apply H1.
-apply le_trans with (2:=Hn).
-apply Max.le_max_l.
+apply Nat.le_trans with (2:=Hn).
+apply Nat.le_max_l.
apply H2.
-apply le_trans with (2:=Hn).
-apply Max.le_max_r.
+apply Nat.le_trans with (2:=Hn).
+apply Nat.le_max_r.
intros P Q H (N,HN).
exists N.
intros n Hn.
@@ -746,7 +747,7 @@ Proof.
intros f P a k p m D H1 H2 HH H3 H4.
case_eq m.
(* *)
-intros _; rewrite plus_0_r.
+intros _; rewrite Nat.add_0_r.
assert (L:(0 < k ^ p / (1 - k) * D)).
apply Rmult_lt_0_compat; trivial.
apply Rdiv_lt_0_compat.
@@ -825,7 +826,7 @@ Proof.
intros a k p m n D (H1,H1') H2 Phi_a H3 H4 Hp Hm Hn.
case H1; intros P0.
(* *)
-case (le_or_lt p m); intros H5.
+case (Nat.le_gt_cases p m); intros H5.
(* . *)
replace m with (p+(m-p))%nat.
apply ball_le with (k^p/(1-k) *D).
@@ -842,7 +843,7 @@ intros p0.
apply phi_iter_f.
trivial.
trivial.
-now apply le_plus_minus_r.
+auto with arith.
(* . *)
apply ball_sym.
replace p with (m+(p-m))%nat.
@@ -858,7 +859,7 @@ apply dist_iter_phi with phi; try assumption.
split; assumption.
intros p0.
apply phi_iter_f; trivial.
-now apply le_plus_minus_r, lt_le_weak.
+auto with arith.
(* *)
apply ball_le with 0.
rewrite <- P0.
@@ -866,8 +867,8 @@ rewrite pow_i; try assumption.
right; unfold Rdiv; ring.
apply dist_iter_aux_zero_phi; try assumption.
now rewrite P0.
-now apply lt_le_trans with n.
-now apply lt_le_trans with n.
+now apply Nat.lt_le_trans with n.
+now apply Nat.lt_le_trans with n.
Qed.
End iter_dist_sub.
@@ -979,11 +980,11 @@ intros P Q (N1,H1) (N2,H2).
exists (max N1 N2).
intros n Hn; split.
apply H1.
-apply le_trans with (2:=Hn).
-apply Max.le_max_l.
+apply Nat.le_trans with (2:=Hn).
+apply Nat.le_max_l.
apply H2.
-apply le_trans with (2:=Hn).
-apply Max.le_max_r.
+apply Nat.le_trans with (2:=Hn).
+apply Nat.le_max_r.
intros P Q H (N,HN).
exists N.
intros n Hn.
diff --git a/coq/CertRL/LM/hierarchyD.v b/rocq/CertRL/LM/hierarchyD.v
similarity index 99%
rename from coq/CertRL/LM/hierarchyD.v
rename to rocq/CertRL/LM/hierarchyD.v
index 371ebff9..0575cb48 100644
--- a/coq/CertRL/LM/hierarchyD.v
+++ b/rocq/CertRL/LM/hierarchyD.v
@@ -416,7 +416,7 @@ Lemma clm_lim_aux2: forall (Fi : ((clm E F) -> Prop) -> Prop),
Proof.
intros Fi FFi HFc t eps.
case (is_zero_dec t); intros Ht.
-exists zero.
+exists (@zero F).
unfold Fri, Fr, cleanFilter.
exists (fun _ => True).
split.
@@ -461,7 +461,7 @@ Definition clm_lim (Fi : ((clm E F) -> Prop) -> Prop)
lim (Fri Fi t) (clm_lim_aux1 Fi H1 t) (clm_lim_aux2 Fi H1 H2 t).
Lemma ball_plus_plus: forall (x a y b:F) (eps:posreal), ball x eps a -> ball y eps b ->
- ball (plus x y) (2*(@norm_factor R_AbsRing F)*eps) (plus a b).
+ ball (@plus F x y) (2*(@norm_factor R_AbsRing F)*eps) (@plus F a b).
Proof.
intros x a y b eps Hx Hy.
pose (v:=@norm_factor R_AbsRing F); fold v.
@@ -642,7 +642,7 @@ apply Rle_trans with
(minus (g x) (f x)))).
right; f_equal.
unfold minus; rewrite <- plus_assoc.
-apply trans_eq with (plus (clm_lim Fi H1 H2 x) (opp (f x))).
+apply trans_eq with (@plus F (clm_lim Fi H1 H2 x) (opp (f x))).
reflexivity.
f_equal.
rewrite plus_assoc.
diff --git a/coq/CertRL/LM/hilbert.v b/rocq/CertRL/LM/hilbert.v
similarity index 95%
rename from coq/CertRL/LM/hilbert.v
rename to rocq/CertRL/LM/hilbert.v
index 42cba33f..55b0a345 100644
--- a/coq/CertRL/LM/hilbert.v
+++ b/rocq/CertRL/LM/hilbert.v
@@ -79,7 +79,7 @@ Module Exports.
Coercion base : class_of >-> ModuleSpace.class_of.
Coercion mixin : class_of >-> mixin_of.
Coercion sort : type >-> Sortclass.
-Coercion AbelianGroup : type >-> AbelianGroup.type.
+Global Coercion AbelianGroup : type >-> AbelianGroup.type.
Canonical AbelianGroup.
Coercion ModuleSpace : type >-> ModuleSpace.type.
Canonical ModuleSpace.
@@ -133,7 +133,7 @@ Proof.
apply PreHilbert.ax2.
Qed.
-Lemma inner_eq_0 : forall x, inner x x = 0 -> x = zero.
+Lemma inner_eq_0 : forall x, inner x x = 0 -> x = (@zero E).
Proof.
apply PreHilbert.ax3.
Qed.
@@ -143,7 +143,7 @@ Proof.
intros x; apply sqrt_pos.
Qed.
-Lemma norm_eq_0: forall x, Hnorm x = 0 -> x = zero.
+Lemma norm_eq_0: forall x, Hnorm x = 0 -> x = (@zero E).
Proof.
intros x; unfold norm; intros H.
assert (inner x x = 0).
@@ -163,7 +163,7 @@ intros x y l.
now rewrite inner_sym inner_scal_l inner_sym.
Qed.
-Lemma inner_zero_l: forall x, inner zero x = 0.
+Lemma inner_zero_l: forall x, inner (@zero E) x = 0.
Proof.
intros x.
apply trans_eq with (inner (scal 0 zero) x).
@@ -172,18 +172,18 @@ rewrite inner_scal_l.
apply Rmult_0_l.
Qed.
-Lemma inner_zero_r: forall x, inner x zero = 0.
+Lemma inner_zero_r: forall x, inner x (@zero E) = 0.
Proof.
intros x.
rewrite inner_sym; apply inner_zero_l.
Qed.
-Lemma inner_plus_l : forall (x y z : E), inner (plus x y) z = inner x z + inner y z.
+Lemma inner_plus_l : forall (x y z : E), inner (@plus E x y) z = inner x z + inner y z.
Proof.
apply PreHilbert.ax5.
Qed.
-Lemma inner_plus_r : forall (x y z : E), inner x (plus y z) = inner x y + inner x z.
+Lemma inner_plus_r : forall (x y z : E), inner x (@plus E y z) = inner x y + inner x z.
Proof.
intros x y z.
now rewrite inner_sym inner_plus_l 2!(inner_sym x).
@@ -202,7 +202,7 @@ intros x y.
now rewrite inner_sym inner_opp_r inner_sym.
Qed.
-Lemma norm_zero: Hnorm zero = 0.
+Lemma norm_zero: Hnorm (@zero E) = 0.
Proof.
unfold Hnorm; now rewrite inner_zero_l sqrt_0.
Qed.
@@ -236,7 +236,7 @@ apply inner_ge_0.
Qed.
Lemma square_expansion_plus: forall x y,
- inner (plus x y) (plus x y) = inner x x + 2 * inner x y + inner y y.
+ inner (@plus E x y) (plus x y) = inner x x + 2 * inner x y + inner y y.
Proof.
intros x y.
rewrite inner_plus_l 2!inner_plus_r.
@@ -260,7 +260,7 @@ Qed.
(** Equalities and inequalities *)
Lemma parallelogram_id: forall x y,
- (inner (plus x y) (plus x y)) + (inner (minus x y) (minus x y))
+ (inner (@plus E x y) (plus x y)) + (inner (minus x y) (minus x y))
= 2*((inner x x) + (inner y y)).
Proof.
intros x y.
@@ -299,7 +299,7 @@ now rewrite 2!squared_norm.
apply Rmult_le_pos; apply norm_ge_0.
Qed.
-Lemma norm_triangle: forall x y, Hnorm (plus x y) <= Hnorm x + Hnorm y.
+Lemma norm_triangle: forall x y, Hnorm (@plus E x y) <= Hnorm x + Hnorm y.
Proof.
intros x y.
apply Rsqr_incr_0_var; unfold Rsqr.
@@ -361,7 +361,7 @@ apply plus_zero_r.
Qed.
Definition PreHilbert_UniformSpace_mixin :=
- UniformSpace.Mixin E zero ball ball_center ball_sym ball_triangle.
+ UniformSpace.Mixin E (@zero E) ball ball_center ball_sym ball_triangle.
Canonical PreHilbert_UniformSpace :=
UniformSpace.Pack E (PreHilbert_UniformSpace_mixin) E.
@@ -672,13 +672,13 @@ intros g Hg1 Hg2.
unfold Hierarchy.ball; simpl; unfold ball; simpl.
simpl in Hf2, Hg2.
(* . *)
-assert (M:4*Rsqr (norm (minus u (scal (/2) (plus f g)))) + Rsqr (norm (minus f g)) =
- 2*Rsqr (norm (minus u f)) + 2*Rsqr (norm (minus u g))).
+assert (M: 4*Rsqr (norm (@minus E u (scal (/2) (@plus E f g)))) + Rsqr (norm (@minus E f g)) =
+ 2*Rsqr (norm (@minus E u f)) + 2*Rsqr (norm (@minus E u g))).
unfold Rsqr at 3 4; rewrite 2!squared_norm.
rewrite <- Rmult_plus_distr_l.
rewrite <- parallelogram_id.
f_equal.
-apply trans_eq with (Rsqr (2*norm (minus u (scal (/ 2) (plus f g))))).
+apply trans_eq with (Rsqr (2*norm (minus u (scal (/ 2) (@plus E f g))))).
unfold Rsqr; ring.
replace 2 with (abs 2) at 1.
2: apply Rabs_right; left; apply Rlt_0_2.
@@ -694,10 +694,15 @@ replace 2 with (plus 1 1).
2: unfold plus; simpl; ring.
rewrite scal_distr_r scal_one.
unfold minus; rewrite <- 2!plus_assoc; f_equal.
-rewrite opp_plus 2!plus_assoc; f_equal.
-apply plus_comm.
+rewrite <- Rsqr_def.
+do 3 f_equal.
+rewrite (plus_comm (opp f)).
+rewrite <- plus_assoc.
+f_equal.
+rewrite plus_comm.
+apply opp_plus.
rewrite <- squared_norm.
-fold (Rsqr (Hnorm (minus (minus u f) (minus u g)))).
+rewrite <- Rsqr_def.
f_equal.
rewrite <- norm_opp.
unfold norm; simpl; f_equal.
@@ -710,10 +715,10 @@ now rewrite plus_opp_r plus_zero_l.
apply Rsqr_incrst_0.
2: apply norm_ge_0.
2: left; apply cond_pos.
-apply Rle_lt_trans with (-4 * (norm (minus u (scal (/ 2) (plus f g))))² +
+apply Rle_lt_trans with (-4 * (norm (minus u (scal (/ 2) (@plus E f g))))² +
2 * (norm (minus u f))² + 2 * (norm (minus u g))²).
right.
-apply Rplus_eq_reg_l with (4 * (norm (minus u (scal (/ 2) (plus f g))))²).
+apply Rplus_eq_reg_l with (4 * (norm (minus u (scal (/ 2) (@plus E f g))))²).
rewrite M; ring.
apply Rle_lt_trans with (-4*Rsqr delta+2 * (norm (minus u f))² +2 * (norm (minus u g))²).
apply Rplus_le_compat_r, Rplus_le_compat_r.
@@ -723,10 +728,10 @@ apply Ropp_le_contravar.
apply Rmult_le_pos; left; apply Rlt_0_2.
apply Rsqr_incr_1; try assumption.
assert (H2:Rbar_le (Glb_Rbar (fun r : R => exists w0 : E, phi w0
- /\ r = norm (minus u w0))) (norm (minus u (scal (/ 2) (plus f g))))).
+ /\ r = norm (@minus E u w0))) (norm (@minus E u (scal (/ 2) (@plus E f g))))).
apply Glb_Rbar_correct.
-exists (scal (/ 2) (plus f g)); split; try easy.
-replace (scal (/ 2) (plus f g)) with
+exists (scal (/ 2) (@plus E f g)); split; try easy.
+replace (scal (/ 2) (@plus E f g)) with
(plus (scal (/2) f) (scal (1-/2) g)).
apply phi_convex; try assumption.
split.
@@ -865,7 +870,7 @@ intros H u v v' H0 H1 H2 H3.
rewrite <- H3 in H1.
pose (a := minus u v').
pose (b := minus u v).
-pose (v'' := scal (1/2) (plus v v')).
+pose (v'' := scal (1/2) (@plus E v v')).
pose (d := norm (minus u v)).
assert (E1 : plus (4*((norm (minus u v''))*(norm (minus u v''))))
((norm (minus v v'))*(norm (minus v v')))
@@ -905,12 +910,12 @@ rewrite plus_comm.
rewrite (plus_comm _ ((norm (minus v v') * norm (minus v v')))).
apply Rplus_eq_compat_l.
replace (plus (minus u v') (minus u v))
- with (minus ((scal 2) u) (plus v v')).
+ with (minus ((scal 2) u) (@plus E v v')).
unfold minus.
-assert (H42 : (4 *(norm (plus u (opp (scal (1 / 2) (plus v v')))) *
- norm (plus u (opp (scal (1 / 2) (plus v v'))))))
- =((2*(norm (plus u (opp (scal (1 / 2) (plus v v'))))))*
- (2*(norm (plus u (opp (scal (1 / 2) (plus v v')))))))).
+assert (H42 : (4 *(norm (@plus E u (opp (scal (1 / 2) (@plus E v v')))) *
+ norm (@plus E u (opp (scal (1 / 2) (@plus E v v'))))))
+ =((2*(norm (@plus E u (opp (scal (1 / 2) (@plus E v v'))))))*
+ (2*(norm (@plus E u (opp (scal (1 / 2) (@plus E v v')))))))).
ring.
rewrite H42.
replace 2 with (Rabs 2) at 1.
@@ -954,12 +959,9 @@ rewrite plus_comm.
rewrite opp_opp.
rewrite (plus_comm (opp u) v).
rewrite plus_assoc.
-rewrite <- (plus_assoc u (opp v') v).
-rewrite (plus_comm u (plus (opp v') v)).
-rewrite <- (plus_assoc (plus (opp v') v) u (opp u)).
-rewrite plus_opp_r.
-rewrite plus_zero_r.
-reflexivity.
+rewrite (plus_comm _ (opp u)).
+repeat rewrite plus_assoc.
+now rewrite plus_opp_l plus_zero_l.
assert (Hmin : norm (minus v v') * norm (minus v v') <= 0).
replace 0 with (plus (4*(d*d)) (-4*(d*d))).
replace (norm (minus v v') * norm (minus v v')) with
@@ -968,8 +970,8 @@ replace (norm (minus v v') * norm (minus v v')) with
unfold minus.
apply Rplus_le_compat_l.
replace (-4 * (d * d)) with (opp (4*(d*d))).
-assert (Has: opp (4 * (norm (plus u (opp v'')) * norm (plus u (opp v''))))
- = (-((4 * (norm (plus u (opp v'')) * norm (plus u (opp v''))))))).
+assert (Has: opp (4 * (norm (@plus E u (opp v'')) * norm (@plus E u (opp v''))))
+ = (-((4 * (norm (@plus E u (opp v'')) * norm (@plus E u (opp v''))))))).
reflexivity.
rewrite Has.
clear Has.
@@ -1001,7 +1003,7 @@ assert (Hf :Rbar_le x (norm (minus u v''))).
apply Hp.
unfold v''.
unfold convex in phi_convex.
-replace (scal (1 / 2) (plus v v'))
+replace (scal (1 / 2) (@plus E v v'))
with
(plus (scal (1/2) v) (scal (1/2) v')).
replace (1 / 2) with (1 - (1 / 2)) at 2 by field.
@@ -1123,7 +1125,7 @@ intros. unfold minus at 1.
rewrite scal_distr_r. rewrite scal_one. rewrite scal_opp_l.
rewrite opp_plus. rewrite (opp_plus v (opp (scal t v))).
rewrite opp_opp. rewrite plus_assoc. rewrite plus_assoc.
-rewrite (plus_comm (plus u (opp (scal t w))) (opp v)).
+rewrite (plus_comm (@plus E u (opp (scal t w))) (opp v)).
rewrite plus_assoc. unfold minus at 1.
rewrite (plus_comm (opp v) u). unfold minus.
rewrite <- plus_assoc. rewrite <- scal_opp_l.
@@ -1227,7 +1229,7 @@ Lemma charac_ortho_projection_convex_aux1_r : forall u v w :E, phi v ->
<= ((norm (minus u w) * norm (minus u w)))).
intros.
assert (minus u w = plus (minus u v) (minus v w)).
-unfold minus. rewrite (plus_comm u (opp v)).
+unfold minus. rewrite (@plus_comm E u (opp v)).
rewrite plus_assoc_gen. rewrite plus_opp_l.
rewrite plus_zero_l. reflexivity.
rewrite H2.
@@ -1332,7 +1334,7 @@ intros; split.
intros.
assert
(forall w, minus u w = plus (minus u v) (minus v w)).
- unfold minus. symmetry. rewrite (plus_comm u (opp v)).
+ unfold minus. symmetry. rewrite (@plus_comm E u (opp v)).
rewrite plus_assoc_gen. rewrite plus_opp_l.
rewrite plus_zero_l. reflexivity.
assert
@@ -1433,7 +1435,7 @@ apply mod2.
trivial.
apply unique_existence1; split.
apply ortho_projection_convex'; try easy.
-exists zero.
+exists (@zero E).
apply (compatible_m_zero phi phi_mod).
intros v v' Hv Hv'.
apply (ortho_projection_convex_unique phi phi_convex phi_compl u); intuition.
@@ -1489,7 +1491,7 @@ rewrite inner_scal_l.
rewrite (H1 _ Hy); ring.
Qed.
-Lemma trivial_orth_compl : forall u : E, ((forall v : E, inner u v = 0) <-> u = zero).
+Lemma trivial_orth_compl : forall u : E, ((forall v : E, inner u v = 0) <-> u = @zero E).
intros; split.
intro H0. assert (inner u u = 0). apply H0; trivial.
apply PreHilbert.ax3 in H. trivial.
@@ -1497,7 +1499,7 @@ intros. rewrite H. rewrite inner_zero_l; reflexivity.
Qed.
Lemma trivial_orth_compl' : forall (phi : E -> Prop) (u : E),
- closed phi -> phi u -> ((forall v : E, phi v -> inner u v = 0) <-> u = zero).
+ closed phi -> phi u -> ((forall v : E, phi v -> inner u v = 0) <-> u = @zero E).
intros; split.
intro H1. assert (inner u u = 0). apply H1; trivial.
apply PreHilbert.ax3 in H2. trivial.
@@ -1524,7 +1526,7 @@ split; intros.
+ assert
(forall w:E, phi w -> inner (minus u v) (minus w v) <= 0).
apply charac_ortho_projection_subspace1; intuition.
- pose (w' := plus w v).
+ pose (w' := @plus E w v).
assert (inner (minus u v) w <= 0).
assert (Hm1 : w = minus w' v).
unfold minus.
@@ -1670,7 +1672,7 @@ Qed.
Lemma direct_sum_with_orth_compl_charac2: forall u v,
phi v -> norm (minus u v)
= Glb_Rbar (fun r => exists w:E, phi w /\ r = norm (minus u w)) ->
- (orth_compl u <-> v = zero).
+ (orth_compl u <-> v = @zero E).
split; intros.
+ unfold orth_compl in H1.
assert
@@ -1688,7 +1690,7 @@ split; intros.
intros.
apply direct_sum_with_orth_compl_decomp in H0.
assert
- (forall x : E, phi x -> orth_compl x -> x = zero).
+ (forall x : E, phi x -> orth_compl x -> x = @zero E).
apply direct_sumable_with_orth_compl. unfold orth_compl in H0.
assert (inner (minus u zero) (minus w zero) = 0).
unfold minus. rewrite opp_zero.
diff --git a/coq/CertRL/LM/lax_milgram.v b/rocq/CertRL/LM/lax_milgram.v
similarity index 98%
rename from coq/CertRL/LM/lax_milgram.v
rename to rocq/CertRL/LM/lax_milgram.v
index 1d860f94..83e54f31 100644
--- a/coq/CertRL/LM/lax_milgram.v
+++ b/rocq/CertRL/LM/lax_milgram.v
@@ -206,7 +206,7 @@ Theorem Riesz_Frechet'_zero_phi : forall (f:topo_dual E),
phi v -> f v = inner u v.
Proof.
intros f H.
-exists zero.
+exists (@zero E).
split.
destruct m_C as ((CG1,(z,CG2)),CS).
unfold compatible_m in m_C.
@@ -242,9 +242,9 @@ apply C1; trivial.
apply C1'; trivial.
exists zero.
split.
-apply compatible_m_zero.
+apply (compatible_m_zero phi0).
apply C.
-apply compatible_m_zero.
+apply (compatible_m_zero phi0').
apply C'.
intros x l.
split.
@@ -316,7 +316,7 @@ assert (forall u, exists! v:E,
= Glb_Rbar (fun r => exists w:E, PHI w
/\ r = norm (minus u w))).
intros u; apply: ortho_projection_convex.
-exists zero.
+exists (@zero E).
split.
(*apply ker_nnker_equiv.*)
apply: compatible_m_zero.
@@ -772,7 +772,7 @@ exfalso.
apply H.
exists v.
split; trivial.
-assert (u = zero).
+assert (u = @zero E).
rewrite <- H0.
assert (exists! (u:E), phi u /\ forall (v:E),
phi v -> f v = inner u v).
@@ -803,7 +803,7 @@ trivial.
apply H2'.
rewrite H2.
rewrite norm_zero.
-assert (forall u, u <> zero -> phi u -> norm (f u) <= 0 * norm u).
+assert (forall u, u <> @zero E -> phi u -> norm (f u) <= 0 * norm u).
intros u0 H3 H3'.
unfold f_phi_neq_zero.
rewrite H1.
@@ -861,7 +861,7 @@ apply H0.
apply (iota_elim _ _ ) in H0.
unfold tau.
rewrite H0.
-assert (u <> zero).
+assert (u <> @zero E).
rewrite Heq.
destruct H1 as (H11,H12).
assert (f v = inner u1 v).
@@ -1720,7 +1720,7 @@ destruct H as (H,H2).
rewrite <- plus_assoc in H.
destruct H as (Pu,H).
symmetry in H.
-assert (Hu0 : plus u zero = u).
+assert (Hu0 : @plus E u zero = u).
rewrite plus_zero_r. reflexivity.
rewrite <- Hu0 in H at 1.
apply plus_reg_l in H.
@@ -1812,7 +1812,7 @@ rewrite <- plus_assoc.
rewrite <- plus_assoc.
rewrite plus_comm.
rewrite (plus_comm (opp v') _).
-assert (Hp1 : forall a b, plus (plus a b) v
+assert (Hp1 : forall a b, plus (@plus E a b) v
= plus a (plus b v)).
intros.
rewrite plus_assoc.
@@ -1820,7 +1820,7 @@ reflexivity.
rewrite Hp1.
rewrite Hp1.
rewrite Hp1.
-rewrite (plus_assoc v (opp v') _).
+rewrite (@plus_assoc E v (opp v') _).
assert (Hp2 : forall a b, plus a (plus b
(plus (opp v') v)) =
plus (plus a b) (plus (opp v') v)).
@@ -1830,30 +1830,30 @@ rewrite Hp2.
rewrite Hp2.
rewrite plus_comm.
rewrite (plus_comm (opp v') v).
-assert (Hpaux : ((plus (opp (scal r (Tau (A v))))
+assert (Hpaux : ((@plus E (opp (scal r (Tau (A v))))
(plus (scal r (Tau f))
(plus (opp (opp (scal r (Tau (A v')))))
(opp (scal r (Tau f)))))))
- = (opp (scal r (Tau (A (plus v (opp v'))))))).
+ = (opp (scal r (Tau (A (@plus E v (opp v'))))))).
rewrite plus_assoc.
rewrite opp_opp.
rewrite (plus_comm _ (opp (scal r (Tau f)))).
-assert (Hg : forall a b c d : E, plus (plus a b) (plus c d)
- = plus (plus a d) (plus b c)).
+assert (Hg : forall a b c d : E, @plus E (@plus E a b) (@plus E c d)
+ = @plus E (@plus E a d) (@plus E b c)).
intros.
rewrite plus_assoc.
rewrite plus_assoc.
rewrite plus_comm.
-rewrite (plus_comm a0 d).
-rewrite <- (plus_assoc d a0 b).
-rewrite <- (plus_assoc d _ _).
+rewrite (@plus_comm E a0 d).
+rewrite <- (@plus_assoc E d a0 b).
+rewrite <- (@plus_assoc E d _ _).
reflexivity.
rewrite Hg.
rewrite plus_opp_r.
rewrite plus_zero_r.
-replace (opp (scal r (Tau (A (plus v (opp v'))))))
+replace (opp (scal r (Tau (A (@plus E v (opp v'))))))
with
- (scal r (Tau (A (opp (plus v (opp v')))))).
+ (scal r (Tau (A (opp (@plus E v (opp v')))))).
replace ((opp (scal r (Tau (A v)))))
with
(scal r (Tau (A (opp v)))).
@@ -1891,14 +1891,14 @@ rewrite Hsr.
reflexivity.
rewrite <- scal_opp_r.
rewrite <- scal_opp_one.
-replace ((scal (opp one) (Tau (A (plus v (opp v'))))))
+replace ((scal (opp one) (Tau (A (@plus E v (opp v'))))))
with
- ((Tau (scal (opp one) (A (plus v (opp v')))))).
+ ((Tau (scal (opp one) (A (@plus E v (opp v')))))).
assert (H : is_linear_mapping Tau).
apply Riesz_Frechet_moreover2_phi; trivial.
destruct H as (Hl1,Hl2).
-replace (Tau (A (opp (plus v (opp v')))))
- with (Tau (scal (opp one) (A (plus v (opp v'))))).
+replace (Tau (A (opp (@plus E v (opp v')))))
+ with (Tau (scal (opp one) (A (@plus E v (opp v'))))).
reflexivity.
clear Hp1 Hp2 Hg.
assert (Hlb : forall l y, scal l (A y)
@@ -2353,7 +2353,7 @@ intros u' Hu'.
symmetry.
apply (proj2 Hu).
trivial.
-exists zero.
+exists (@zero E).
split.
unfold is_sol_linear_pb_phi.
split.
@@ -2363,12 +2363,12 @@ intros v Hv.
apply Is_only_zero_set_correct1_phi with E phi v in i.
rewrite i.
destruct Hba as ((Hba1,(Hba2,Hba3)),Hba4).
-assert (a zero zero = a (scal zero zero) zero).
+assert (a (@zero E) (@zero E) = a (scal zero zero) (@zero E)).
rewrite scal_zero_l.
reflexivity.
rewrite H.
-replace (a (scal zero zero) zero)
- with (scal zero (a zero zero)).
+replace (a (scal zero zero) (@zero E))
+ with (scal zero (a (@zero E) (@zero E))).
rewrite scal_zero_l.
assert (is_linear_mapping f).
apply f.
diff --git a/coq/CertRL/LM/lax_milgram_cea.v b/rocq/CertRL/LM/lax_milgram_cea.v
similarity index 94%
rename from coq/CertRL/LM/lax_milgram_cea.v
rename to rocq/CertRL/LM/lax_milgram_cea.v
index 032ae62d..bb73b3d4 100644
--- a/coq/CertRL/LM/lax_milgram_cea.v
+++ b/rocq/CertRL/LM/lax_milgram_cea.v
@@ -138,10 +138,11 @@ destruct H0 as (H0,H7).
destruct H7 as (H7,(H8,H9)).
unfold minus.
rewrite H9.
-replace (a (plus u (opp uh)) u) with
- (plus (a (plus u (opp uh)) u) 0).
+transitivity (plus (a (@plus E u (opp uh)) u) 0); cycle 1.
+{ unfold plus; simpl.
+ apply Rplus_0_r.
+}
f_equal.
-now rewrite plus_zero_r.
specialize (H1 uh).
apply H1 in X1.
specialize (H2 uh).
@@ -165,10 +166,12 @@ rewrite H7 scal_opp_one.
reflexivity.
rewrite scal_opp_one.
reflexivity.
-rewrite plus_zero_r; reflexivity.
-replace (M * norm (minus u vh) * norm (minus u uh)) with
- (M * norm (minus u uh) * norm (minus u vh)) by ring.
-assumption.
+eapply Rle_trans; try eapply H4.
+right.
+simpl.
+repeat rewrite Rmult_assoc.
+f_equal.
+now rewrite Rmult_comm.
destruct Hca.
intro Hk.
rewrite Hk in H0.
diff --git a/coq/CertRL/LM/linear_map.v b/rocq/CertRL/LM/linear_map.v
similarity index 88%
rename from coq/CertRL/LM/linear_map.v
rename to rocq/CertRL/LM/linear_map.v
index 16d90a67..94b708f3 100644
--- a/coq/CertRL/LM/linear_map.v
+++ b/rocq/CertRL/LM/linear_map.v
@@ -66,15 +66,21 @@ Lemma fct_plus_opp_r: forall f:E->F, fct_plus f (fct_opp f) = fct_zero.
Proof.
intros f.
apply functional_extensionality.
-intros x; apply plus_opp_r.
+intros x; apply (@plus_opp_r F).
Qed.
+Definition fct_AbelianMonoid_mixin :=
+ AbelianMonoid.Mixin (E->F) fct_plus fct_zero fct_plus_comm
+ fct_plus_assoc fct_plus_zero_r.
+
+Canonical fct_AbelianMonoid :=
+ AbelianMonoid.Pack (E->F) (fct_AbelianMonoid_mixin) (E->F).
+
Definition fct_AbelianGroup_mixin :=
- AbelianGroup.Mixin (E->F) fct_plus fct_opp fct_zero fct_plus_comm
- fct_plus_assoc fct_plus_zero_r fct_plus_opp_r.
+ AbelianGroup.Mixin fct_AbelianMonoid fct_opp fct_plus_opp_r.
Canonical fct_AbelianGroup :=
- AbelianGroup.Pack (E->F) (fct_AbelianGroup_mixin) (E->F).
+ AbelianGroup.Pack (E->F) (AbelianGroup.Class _ (fct_AbelianMonoid_mixin) fct_AbelianGroup_mixin) (E->F).
Lemma fct_scal_assoc: forall x y (u:E->F),
fct_scal x (fct_scal y u) = fct_scal (mult x y) u.
@@ -138,7 +144,10 @@ split.
unfold plus at 1 4 5; unfold opp; simpl.
unfold fct_plus, fct_opp.
rewrite Hf1, Hg1.
- rewrite opp_plus.
+ transitivity (plus (plus (f x) (f y)) (plus (opp (g x)) (opp (g y)))).
+ { f_equal.
+ apply (opp_plus (g x) (g y)).
+ }
repeat rewrite <- plus_assoc.
apply f_equal.
repeat rewrite plus_assoc.
@@ -193,7 +202,7 @@ Proof.
intros f (H1,H2); split.
intros x y; unfold opp; simpl; unfold fct_opp.
rewrite H1.
-apply opp_plus.
+apply (opp_plus (f x) (f y)).
intros x l; unfold opp; simpl; unfold fct_opp.
rewrite H2.
now rewrite <- scal_opp_r.
@@ -254,13 +263,23 @@ intros x y l;unfold plus; unfold opp; simpl;unfold fct_plus, fct_opp;unfold plus
unfold fct_plus, fct_opp;rewrite Hf2,Hg2;rewrite <- scal_opp_r;now rewrite scal_distr_l.
split.
intros x y z; unfold plus at 1 4 5; unfold opp;simpl; unfold fct_plus, fct_opp;
- unfold plus at 1 5 6;unfold opp;simpl;unfold fct_plus, fct_opp;rewrite Hf3,Hg3;
- rewrite opp_plus;rewrite plus_assoc;rewrite plus_assoc;apply f_equal2;trivial.
- rewrite <- plus_assoc;rewrite (plus_comm (f y z) (opp (g x z)));now rewrite plus_assoc.
-intros x y z; unfold plus at 1 4 5; unfold opp;simpl; unfold fct_plus, fct_opp.
- unfold plus at 1 4 5;unfold opp;simpl;unfold fct_plus, fct_opp. rewrite Hf4,Hg4;
- rewrite opp_plus;rewrite plus_assoc;rewrite plus_assoc;apply f_equal2;trivial.
- rewrite <- plus_assoc;rewrite (plus_comm (f x z) (opp (g x y)));now rewrite plus_assoc.
+ unfold plus at 1 5 6;unfold opp;simpl;unfold fct_plus, fct_opp;rewrite Hf3,Hg3.
+repeat rewrite <- plus_assoc.
+f_equal.
+rewrite (plus_comm (opp (g x z))).
+repeat rewrite <- plus_assoc.
+f_equal.
+rewrite (plus_comm (opp _)).
+apply (opp_plus (g x z) (g y z)).
+intros x y z; unfold plus at 1 4 5; unfold opp;simpl; unfold fct_plus, fct_opp;
+ unfold plus at 1 4 5;unfold opp;simpl;unfold fct_plus, fct_opp. rewrite Hf4,Hg4.
+repeat rewrite <- plus_assoc.
+f_equal.
+rewrite (plus_comm (opp (g x y))).
+repeat rewrite <- plus_assoc.
+f_equal.
+rewrite (plus_comm (opp _)).
+apply (opp_plus (g x y) (g x z)).
exists zero;split.
unfold zero;intros;simpl;unfold fct_zero; simpl;unfold zero;simpl;unfold fct_zero;
now rewrite scal_zero_r.
diff --git a/coq/CertRL/LM/logic_tricks.v b/rocq/CertRL/LM/logic_tricks.v
similarity index 96%
rename from coq/CertRL/LM/logic_tricks.v
rename to rocq/CertRL/LM/logic_tricks.v
index 5190454a..d350b6cf 100644
--- a/coq/CertRL/LM/logic_tricks.v
+++ b/rocq/CertRL/LM/logic_tricks.v
@@ -118,17 +118,16 @@ intros k Hk.
replace k with 0%nat.
apply H.
intros m Hm; contradict Hm.
-apply lt_n_0.
+auto with arith.
generalize Hk; case k; trivial.
intros m Hm; contradict Hm.
-apply le_not_lt.
-now auto with arith.
+now auto with arith.
intros k Hk.
apply H.
intros m Hm.
apply IHn.
-apply lt_le_trans with (1:=Hm).
-now apply gt_S_le.
+apply Nat.lt_le_trans with (1:=Hm).
+auto with arith.
apply H0.
apply le_n.
Qed.
diff --git a/coq/CertRL/README.md b/rocq/CertRL/README.md
similarity index 100%
rename from coq/CertRL/README.md
rename to rocq/CertRL/README.md
diff --git a/coq/CertRL/cond_expt.v b/rocq/CertRL/cond_expt.v
similarity index 100%
rename from coq/CertRL/cond_expt.v
rename to rocq/CertRL/cond_expt.v
diff --git a/coq/CertRL/finite_time.v b/rocq/CertRL/finite_time.v
similarity index 100%
rename from coq/CertRL/finite_time.v
rename to rocq/CertRL/finite_time.v
diff --git a/coq/CertRL/mdp.v b/rocq/CertRL/mdp.v
similarity index 99%
rename from coq/CertRL/mdp.v
rename to rocq/CertRL/mdp.v
index 5c721f05..eede0e86 100644
--- a/coq/CertRL/mdp.v
+++ b/rocq/CertRL/mdp.v
@@ -61,12 +61,12 @@ Record MDP := mkMDP {
act : forall s: state, Type;
(** The state space has decidable equality.*)
- st_eqdec :> EqDec state eq;
- act_eqdec :> (forall s, EqDec (act s) eq);
+ st_eqdec ::> EqDec state eq;
+ act_eqdec ::> (forall s, EqDec (act s) eq);
(** The state and action spaces are finite. *)
- fs :> FiniteType (state) ;
- fa :> forall s, FiniteType (act s);
+ fs ::> FiniteType (state) ;
+ fa ::> forall s, FiniteType (act s);
(** The state space and the fibered action spaces are nonempty. *)
ne : NonEmpty (state) ;
@@ -96,8 +96,6 @@ Proof.
eapply FiniteType_fun_dep ; eauto.
- apply fs.
- apply fa.
- Unshelve.
- apply st_eqdec.
Qed.
Global Instance act_finite (M : MDP) : FiniteType (sigT M.(act))
@@ -407,12 +405,19 @@ Proof.
lra.
Qed.
+
+Definition Rfct_AbelianMonoid_mixin :=
+ AbelianMonoid.Mixin (Rfct A) Rfct_plus Rfct_zero Rfct_plus_comm
+ Rfct_plus_assoc Rfct_plus_zero_r.
+
+Canonical Rfct_AbelianMonoid :=
+ AbelianMonoid.Pack (Rfct A) (Rfct_AbelianMonoid_mixin) (Rfct A).
+
Definition Rfct_AbelianGroup_mixin :=
- AbelianGroup.Mixin (Rfct A) Rfct_plus Rfct_opp Rfct_zero Rfct_plus_comm
- Rfct_plus_assoc Rfct_plus_zero_r Rfct_plus_opp_r.
+ AbelianGroup.Mixin Rfct_AbelianMonoid Rfct_opp Rfct_plus_opp_r.
Canonical Rfct_AbelianGroup :=
- AbelianGroup.Pack (Rfct A) (Rfct_AbelianGroup_mixin) (Rfct A).
+ AbelianGroup.Pack (Rfct A) (AbelianGroup.Class _ (Rfct_AbelianMonoid_mixin) Rfct_AbelianGroup_mixin) (Rfct A).
End Rfct_AbelianGroup.
diff --git a/coq/CertRL/mdp_turtle.v b/rocq/CertRL/mdp_turtle.v
similarity index 96%
rename from coq/CertRL/mdp_turtle.v
rename to rocq/CertRL/mdp_turtle.v
index d296dd12..d8168f26 100644
--- a/coq/CertRL/mdp_turtle.v
+++ b/rocq/CertRL/mdp_turtle.v
@@ -33,14 +33,14 @@ Section turtle.
Definition turtle_state max_x max_y := prod ({x:nat | x < max_x}%nat) ({y:nat | y < max_y}%nat).
(* Convenience method for creating a state with known in-bounds constant co-ordinates *)
- Definition make_turtle_state max_x max_y x y : if lt_dec x max_x
- then if lt_dec y max_y
+ Definition make_turtle_state max_x max_y x y : if Compare_dec.lt_dec x max_x
+ then if Compare_dec.lt_dec y max_y
then turtle_state max_x max_y
else True
else True.
Proof.
- destruct (lt_dec x max_x).
- - destruct (lt_dec y max_y).
+ destruct (Compare_dec.lt_dec x max_x).
+ - destruct (Compare_dec.lt_dec y max_y).
+ apply pair.
* exists x; trivial.
* exists y; trivial.
@@ -126,9 +126,9 @@ Section turtle.
:= (let '(x,y) := s in
match a with
| Up => if proj1_sig y == 0 then None else Some (x, y-1)
- | Down => if lt_dec (y+1) max_y then Some (x, y+1) else None
+ | Down => if Compare_dec.lt_dec (y+1) max_y then Some (x, y+1) else None
| Left => if proj1_sig x == 0 then None else Some (x-1, y)
- | Right => if lt_dec (x+1) max_x then Some (x+1, y) else None
+ | Right => if Compare_dec.lt_dec (x+1) max_x then Some (x+1, y) else None
end)%nat.
Next Obligation.
lia.
@@ -253,7 +253,7 @@ End optimal_path.
Section to_string.
Section utils.
- Definition newline := String (Ascii.ascii_of_N 10) EmptyString.
+ Definition newline := String (Ascii.ascii_of_N (Npos 10%positive)) EmptyString.
Definition string_bracket (sstart send:string) (smiddle:string)
:= String.append sstart (String.append smiddle send).
diff --git a/coq/CertRL/orderfun.v b/rocq/CertRL/orderfun.v
similarity index 100%
rename from coq/CertRL/orderfun.v
rename to rocq/CertRL/orderfun.v
diff --git a/coq/CertRL/pmf_monad.v b/rocq/CertRL/pmf_monad.v
similarity index 100%
rename from coq/CertRL/pmf_monad.v
rename to rocq/CertRL/pmf_monad.v
diff --git a/coq/CertRL/pmf_prob.v b/rocq/CertRL/pmf_prob.v
similarity index 98%
rename from coq/CertRL/pmf_prob.v
rename to rocq/CertRL/pmf_prob.v
index 3d67c646..6e9f1b10 100644
--- a/coq/CertRL/pmf_prob.v
+++ b/rocq/CertRL/pmf_prob.v
@@ -127,7 +127,7 @@ Section Pmf_PMF.
rewrite Forall_forall; intros lmu.
specialize (lmu _ inn).
lia.
- -- destruct (in_dec eq_nat_dec a l); trivial.
+ -- destruct (in_dec Peano_dec.eq_nat_dec a l); trivial.
specialize (nin _ n).
unfold nequiv_decb, negb, equiv_decb in non0.
destruct (equiv_dec (coll a) 0); congruence.
@@ -283,7 +283,7 @@ Section Pmf_PMF.
(map (fun x : nonnegreal * A => nonneg (fst x))
(filter (fun x : nonnegreal * A => if equiv_dec a (snd x) then true else false) outcomes))
| None => 0
- end) (nodup eq_nat_dec (map countable_index (map snd outcomes))))).
+ end) (nodup Peano_dec.eq_nat_dec (map countable_index (map snd outcomes))))).
- apply infinite_sum'_finite
; intros.
+ match_case; intros.
@@ -389,7 +389,7 @@ Section pmf_prob.
(exist (fun _ : pre_event A => True)
(fun omega : A => pre_event_singleton c (rv_X omega))
(sa_preimage_singleton rv_X c)))
- (nodup eq_nat_dec (map countable_index (map snd pmf.(outcomes)))))
+ (nodup Peano_dec.eq_nat_dec (map countable_index (map snd pmf.(outcomes)))))
; intros HH.
cut_to HH.
- rewrite (infinite_sum'_unique i HH).
diff --git a/coq/CertRL/qvalues.v b/rocq/CertRL/qvalues.v
similarity index 100%
rename from coq/CertRL/qvalues.v
rename to rocq/CertRL/qvalues.v
diff --git a/coq/CertRL/refs.md b/rocq/CertRL/refs.md
similarity index 100%
rename from coq/CertRL/refs.md
rename to rocq/CertRL/refs.md
diff --git a/coq/ProbTheory/Almost.v b/rocq/ProbTheory/Almost.v
similarity index 99%
rename from coq/ProbTheory/Almost.v
rename to rocq/ProbTheory/Almost.v
index b6b73a82..8723f05a 100644
--- a/coq/ProbTheory/Almost.v
+++ b/rocq/ProbTheory/Almost.v
@@ -481,7 +481,7 @@ Section almostR2_part.
apply all_almost; intros ω Pω.
exists N; trivial.
- intros.
- apply le_dec.
+ apply Compare_dec.le_dec.
- trivial.
Qed.
diff --git a/coq/ProbTheory/BorelSigmaAlgebra.v b/rocq/ProbTheory/BorelSigmaAlgebra.v
similarity index 100%
rename from coq/ProbTheory/BorelSigmaAlgebra.v
rename to rocq/ProbTheory/BorelSigmaAlgebra.v
diff --git a/coq/ProbTheory/ConditionalExpectation.v b/rocq/ProbTheory/ConditionalExpectation.v
similarity index 100%
rename from coq/ProbTheory/ConditionalExpectation.v
rename to rocq/ProbTheory/ConditionalExpectation.v
diff --git a/coq/ProbTheory/DiscreteProbSpace.v b/rocq/ProbTheory/DiscreteProbSpace.v
similarity index 99%
rename from coq/ProbTheory/DiscreteProbSpace.v
rename to rocq/ProbTheory/DiscreteProbSpace.v
index 820341c9..77abce95 100644
--- a/coq/ProbTheory/DiscreteProbSpace.v
+++ b/rocq/ProbTheory/DiscreteProbSpace.v
@@ -1209,7 +1209,7 @@ Section countable_products.
intros.
apply Rge_minus.
unfold double_sum.
- destruct (lt_dec n2 n1).
+ destruct (Compare_dec.lt_dec n2 n1).
- rewrite (sum_f_R0_split _ n1 n2); trivial.
apply Rge_trans with (r2 := sum_f_R0 (fun i : nat => sum_f_R0 (fun j : nat => f i j) m1) n2).
+ rewrite <- Rplus_0_r.
@@ -1220,7 +1220,7 @@ Section countable_products.
intros; apply H.
+ apply Rle_ge, sum_f_R0_le.
intros.
- destruct (lt_dec m2 m1).
+ destruct (Compare_dec.lt_dec m2 m1).
* rewrite (sum_f_R0_split _ m1 m2); trivial.
rewrite <- Rplus_0_r at 1.
apply Rplus_le_compat_l.
@@ -1230,7 +1230,7 @@ Section countable_products.
- assert (n1 = n2) by lia; subst.
apply Rle_ge, sum_f_R0_le.
intros.
- destruct (lt_dec m2 m1).
+ destruct (Compare_dec.lt_dec m2 m1).
+ rewrite (sum_f_R0_split _ m1 m2); trivial.
apply Rplus_le_pos_l.
apply sum_f_R0_nneg.
@@ -1284,7 +1284,7 @@ Section countable_products.
Rabs ((double_sum f m n) - (double_sum f n n)) =
Rabs ((double_sum f m m) - (double_sum f n n))).
{
- destruct (ge_dec m n)%nat.
+ destruct (Compare_dec.ge_dec m n)%nat.
rewrite Rabs_right, Rabs_right, Rabs_right; try lra.
apply double_sum_ge; trivial; lia.
apply double_sum_ge; trivial; lia.
@@ -1385,7 +1385,7 @@ Section countable_products.
rewrite <- Lim_seq_const.
apply Lim_seq_le_loc.
exists n; intros.
- destruct (lt_dec n n0).
+ destruct (Compare_dec.lt_dec n n0).
generalize (sum_f_R0_split f n0 n); intros.
rewrite H1; try lia.
apply Rplus_le_pos_l.
diff --git a/coq/ProbTheory/Dynkin.v b/rocq/ProbTheory/Dynkin.v
similarity index 99%
rename from coq/ProbTheory/Dynkin.v
rename to rocq/ProbTheory/Dynkin.v
index fd42f9d0..f819f2b3 100644
--- a/coq/ProbTheory/Dynkin.v
+++ b/rocq/ProbTheory/Dynkin.v
@@ -24,7 +24,7 @@ Section dynkin.
Class Lambda_system (c:pre_event T -> Prop)
:= mk_lambda_system {
lambda_Ω : c pre_Ω
- ; lambda_proper :> Proper (pre_event_equiv ==> iff) c
+ ; lambda_proper ::> Proper (pre_event_equiv ==> iff) c
; lambda_complement {a} : c a -> c (pre_event_complement a)
; lambda_disjoint_countable_union (an : nat -> pre_event T) :
(forall x, c (an x)) ->
@@ -654,7 +654,7 @@ Section monotone_class_def.
Class monotone_class (M : pre_event E -> Prop)
:= mk_monotone_class {
monotone_Ω : M pre_Ω
- ; monotone_proper :> Proper (pre_event_equiv ==> iff) M
+ ; monotone_proper ::> Proper (pre_event_equiv ==> iff) M
; monotone_diff a b :
M a -> M b ->
pre_event_sub a b ->
diff --git a/coq/ProbTheory/Event.v b/rocq/ProbTheory/Event.v
similarity index 99%
rename from coq/ProbTheory/Event.v
rename to rocq/ProbTheory/Event.v
index a2dac2b3..01b05fba 100644
--- a/coq/ProbTheory/Event.v
+++ b/rocq/ProbTheory/Event.v
@@ -2151,7 +2151,7 @@ Section event.
unfold collection_take.
split.
- intros na.
- destruct (lt_dec a n).
+ destruct (Compare_dec.lt_dec a n).
+ split; trivial.
destruct (map_nth_in_exists En (seq 0 n) event_none a).
* now rewrite seq_length.
@@ -2357,7 +2357,7 @@ Section vec.
Definition pre_bounded_inter_of_pre_collection_unbound {T} {n} (collection: forall i (pf:(i match lt_dec i n with
+ (fun i => match Compare_dec.lt_dec i n with
| left pf => collection i pf
| right _ => pre_Ω
end).
@@ -2396,7 +2396,7 @@ Section vec.
Definition bounded_inter_of_collection_unbound {T} {n} {σ: SigmaAlgebra T} (collection: forall i (pf:(i match lt_dec i n with
+ (fun i => match Compare_dec.lt_dec i n with
| left pf => collection i pf
| right _ => Ω
end).
@@ -2719,7 +2719,7 @@ Section pre_make_disjoint.
Definition make_pre_collection_disjoint (coll:nat->pre_event T) : nat -> pre_event T
:= fun x => pre_event_diff (coll x) ((pre_union_of_collection (fun y =>
- if lt_dec y x
+ if Compare_dec.lt_dec y x
then coll y
else pre_event_none))).
@@ -2755,13 +2755,13 @@ Section pre_make_disjoint.
intros y ylt cy.
apply H2.
exists y.
- destruct (lt_dec y x); intuition.
+ destruct (Compare_dec.lt_dec y x); intuition.
- intros [ce fce].
unfold make_pre_collection_disjoint.
split; trivial.
unfold pre_union_of_collection.
intros [n Hn].
- destruct (lt_dec n x); trivial.
+ destruct (Compare_dec.lt_dec n x); trivial.
eapply fce; eauto.
Qed.
@@ -2773,7 +2773,7 @@ Section pre_make_disjoint.
apply make_pre_collection_disjoint_in in e2.
destruct e1 as [H11 H12].
destruct e2 as [H21 H22].
- destruct (not_eq _ _ xyneq) as [xlt|ylt].
+ destruct (Compare_dec.not_eq _ _ xyneq) as [xlt|ylt].
- eapply H22; eauto.
- eapply H12; eauto.
Qed.
@@ -2797,7 +2797,7 @@ Section pre_make_disjoint.
split; trivial.
unfold pre_union_of_collection.
intros [nn Hnn].
- destruct (lt_dec nn m); [ | tauto].
+ destruct (Compare_dec.lt_dec nn m); [ | tauto].
specialize (H0 _ Hnn).
lia.
- apply make_pre_collection_disjoint_in in Hn.
@@ -2870,7 +2870,7 @@ Section more_pre_props.
unfold collection_take.
split.
- intros na.
- destruct (lt_dec a n).
+ destruct (Compare_dec.lt_dec a n).
+ split; trivial.
destruct (map_nth_in_exists En (seq 0 n) pre_event_none a).
* now rewrite seq_length.
@@ -2923,8 +2923,8 @@ Section more_pre_props.
Qed.
Lemma pre_union_of_collection_lt_S {A} (E:nat->pre_event A) n :
- pre_event_equiv (pre_union_of_collection (fun y : nat => if lt_dec y (S n) then E y else pre_event_none))
- (pre_event_union (E n) (pre_union_of_collection (fun y : nat => if lt_dec y n then E y else pre_event_none))).
+ pre_event_equiv (pre_union_of_collection (fun y : nat => if Compare_dec.lt_dec y (S n) then E y else pre_event_none))
+ (pre_event_union (E n) (pre_union_of_collection (fun y : nat => if Compare_dec.lt_dec y n then E y else pre_event_none))).
Proof.
intros ?; split.
- intros [? HH].
diff --git a/coq/ProbTheory/Expectation.v b/rocq/ProbTheory/Expectation.v
similarity index 99%
rename from coq/ProbTheory/Expectation.v
rename to rocq/ProbTheory/Expectation.v
index 0974f7bc..c2bfb0dc 100644
--- a/coq/ProbTheory/Expectation.v
+++ b/rocq/ProbTheory/Expectation.v
@@ -1,4 +1,4 @@
-Require Import Reals.
+Require Import ZArith Reals.
Require Import Lra Lia.
Require Import List Permutation.
@@ -1074,7 +1074,7 @@ Section Expectation_sec.
* invcs H1.
rewrite H4 in H3.
congruence.
- + destruct (gt_dec n 0).
+ + destruct (Compare_dec.gt_dec n 0).
* generalize (find_none _ _ H2); intros.
specialize (H3 0).
rewrite <- in_rev in H3.
@@ -1095,7 +1095,7 @@ Section Expectation_sec.
generalize (pow_exp_gt 2 n); intros.
cut_to H4.
replace (0%nat) with (n*0)%nat at 1 by lia.
- apply mult_lt_compat_l; lia.
+ apply Nat.mul_lt_mono_pos_l; lia.
lia.
* assert (n = 0)%nat by lia.
invcs H3.
@@ -1142,13 +1142,13 @@ Section Expectation_sec.
clear n0.
assert (pos1:(n * 2 ^ n > 0)%nat).
{
- apply NPeano.Nat.mul_pos_pos.
+ apply Nat.mul_pos_pos.
- destruct n; try lia.
simpl in Xlt.
specialize (posX omega).
now apply Rbar_le_not_lt in posX.
- simpl.
- apply NPeano.Nat.Private_NZPow.pow_pos_nonneg
+ apply Nat.Private_NZPow.pow_pos_nonneg
; lia.
}
match_case; intros.
@@ -1191,7 +1191,7 @@ Section Expectation_sec.
simpl in HH.
rewrite app_ass in HH.
rewrite app_nth2 in HH by lia.
- rewrite NPeano.Nat.sub_diag in HH.
+ rewrite Nat.sub_diag in HH.
simpl in HH.
subst.
split; intros.
@@ -1210,7 +1210,7 @@ Section Expectation_sec.
rewrite last_app in eqq0 by congruence.
simpl in eqq0.
rewrite <- eqq0.
- rewrite NPeano.Nat.sub_1_r.
+ rewrite Nat.sub_1_r.
rewrite Nat.succ_pred_pos by trivial.
rewrite mult_INR.
simpl.
@@ -1234,7 +1234,7 @@ Section Expectation_sec.
rewrite eqq4 in Fl1.
invcs Fl1.
match_destr_in H1.
- rewrite NPeano.Nat.add_1_r in n0.
+ rewrite Nat.add_1_r in n0.
rewrite Rbar_div_Rdiv.
now apply Rbar_not_le_lt in n0.
+ rewrite app_length; simpl.
@@ -1251,8 +1251,8 @@ Section Expectation_sec.
rewrite <- Rbar_div_Rdiv in r0.
rewrite Rbar_le_div_l in r0 ; [| apply pow2_pos].
{
- destruct (lt_eq_lt_dec (length d) k) as [[lt1|]|lt1]; trivial
- ; elimtype False.
+ destruct (Compare_dec.lt_eq_lt_dec (length d) k) as [[lt1|]|lt1]; trivial
+ ; exfalso.
- generalize (f_equal (fun x => nth k x 0)%nat); intros HH.
specialize (HH _ _ eqq2).
{
@@ -1363,14 +1363,7 @@ Section Expectation_sec.
intros posX.
intros omega k.
intros klt.
- assert (pos1:(n * 2 ^ n > 0)%nat).
- {
- apply NPeano.Nat.mul_pos_pos.
- - destruct n; try lia.
- - simpl.
- apply NPeano.Nat.Private_NZPow.pow_pos_nonneg
- ; lia.
- }
+ assert (pos1:(n * 2 ^ n > 0)%nat) by lia.
unfold simple_approx.
split; intros HH.
- match_destr_in HH.
@@ -1414,7 +1407,7 @@ Section Expectation_sec.
rewrite last_app in eqq0 by congruence.
simpl in eqq0.
subst.
- rewrite NPeano.Nat.sub_1_r.
+ rewrite Nat.sub_1_r.
rewrite Nat.succ_pred_pos by trivial.
rewrite mult_INR.
unfold Rdiv.
@@ -1453,7 +1446,7 @@ Section Expectation_sec.
rewrite app_length in HH.
replace ((S (length (rev l2)) - (length (rev l2) + length [length (rev l2)])))%nat with 0%nat in HH.
* rewrite rev_nth in HH by (simpl; lia).
- rewrite plus_0_l in HH.
+ rewrite Nat.add_0_l in HH.
rewrite HH.
rewrite Forall_forall in Fl1.
specialize (Fl1 (nth (length (n1 :: l1) - 1) (n1 :: l1) 0%nat)).
@@ -1537,8 +1530,8 @@ Section Expectation_sec.
subst.
apply Rmult_eq_compat_r.
f_equal.
- destruct (lt_eq_lt_dec (length (rev l2)) k) as [[lt1|]|lt1]; trivial
- ; elimtype False.
+ destruct (Compare_dec.lt_eq_lt_dec (length (rev l2)) k) as [[lt1|]|lt1]; trivial
+ ; exfalso.
- generalize (f_equal (fun x => nth k x ((fun x : nat => INR x / 2 ^ n) 0%nat))); intros HH.
specialize (HH _ _ eqq2).
{
@@ -1991,7 +1984,7 @@ Section Expectation_sec.
apply Rinv_0_lt_compat.
apply cond_pos.
++ apply Rle_pow; [lra |].
- apply Max.le_max_l.
+ apply Nat.le_max_l.
* rewrite Rinv_involutive.
reflexivity.
apply Rgt_not_eq.
@@ -2006,7 +1999,7 @@ Section Expectation_sec.
specialize (posX ω).
now rewrite <- isfin in posX; simpl in posX.
* apply le_INR.
- apply Max.le_max_r.
+ apply Nat.le_max_r.
Qed.
Lemma simple_approx_lim_seq (X:Ts -> Rbar) (posX : Rbar_NonnegativeFunction X) :
diff --git a/coq/ProbTheory/FunctionsToReal.v b/rocq/ProbTheory/FunctionsToReal.v
similarity index 97%
rename from coq/ProbTheory/FunctionsToReal.v
rename to rocq/ProbTheory/FunctionsToReal.v
index 8cb665fd..e4ae64ef 100644
--- a/coq/ProbTheory/FunctionsToReal.v
+++ b/rocq/ProbTheory/FunctionsToReal.v
@@ -81,7 +81,7 @@ Section defs.
(fun omega => (rv_X1 omega) + (rv_X2 omega)).
Definition rvsum (Xn : nat -> Ts -> R) (n : nat) :=
- (fun omega => sum_n (fun n0 => Xn n0 omega) n).
+ (fun omega => @sum_n R_AbelianGroup (fun n0 => Xn n0 omega) n).
Definition rvscale (c:R) (rv_X : Ts -> R) :=
fun omega => c * (rv_X omega).
@@ -845,7 +845,7 @@ Section defs.
Qed.
Lemma Rbar_Rabs_lim_sum_le0 (f : nat -> Ts -> R) (x : Ts) :
- is_finite (Lim_seq (sum_n (fun n=> Rabs (f n x)))) ->
+ is_finite (Lim_seq (@sum_n R_AbelianGroup (fun n=> Rabs (f n x)))) ->
Rbar_le
(Rbar_abs (Lim_seq (fun n => (rvsum f) n x)))
(Rbar_abs (Lim_seq (fun n => (rvsum (fun n0 => (rvabs (f n0))) n x)))).
@@ -859,11 +859,11 @@ Section defs.
apply ex_series_Lim_seq in H.
apply ex_series_Lim_seq in H0.
replace (Lim_seq
- (fun n : nat => sum_n (fun n0 : nat => f n0 x) n))
+ (fun n : nat => @sum_n R_AbelianGroup (fun n0 : nat => f n0 x) n))
with (Finite ( Series (fun n : nat => f n x))).
replace (Lim_seq
(fun n : nat =>
- sum_n (fun n0 : nat => Rabs (f n0 x)) n))
+ @sum_n R_AbelianGroup (fun n0 : nat => Rabs (f n0 x)) n))
with (Finite (Series (fun n : nat => Rabs (f n x)))).
simpl.
apply Rge_le.
@@ -875,7 +875,7 @@ Section defs.
Qed.
Lemma Rabs_lim_sum_le0 (f : nat -> Ts -> R) (x : Ts) :
- is_finite (Lim_seq (sum_n (fun n=> Rabs (f n x)))) ->
+ is_finite (Lim_seq (@sum_n R_AbelianGroup (fun n=> Rabs (f n x)))) ->
Rbar_le
(Rbar_abs (Finite (real (Lim_seq (fun n => (rvsum f) n x)))))
(Rbar_abs (Lim_seq (fun n => (rvsum (fun n0 => (rvabs (f n0))) n x)))).
@@ -889,11 +889,11 @@ Section defs.
apply ex_series_Lim_seq in H.
apply ex_series_Lim_seq in H0.
replace (Lim_seq
- (fun n : nat => sum_n (fun n0 : nat => f n0 x) n))
+ (fun n : nat => @sum_n R_AbelianGroup (fun n0 : nat => f n0 x) n))
with (Finite ( Series (fun n : nat => f n x))).
replace (Lim_seq
(fun n : nat =>
- sum_n (fun n0 : nat => Rabs (f n0 x)) n))
+ @sum_n R_AbelianGroup (fun n0 : nat => Rabs (f n0 x)) n))
with (Finite (Series (fun n : nat => Rabs (f n x)))).
simpl.
apply Rge_le.
@@ -915,11 +915,11 @@ Section defs.
- rewrite <- H.
apply Rbar_Rabs_lim_sum_le0.
unfold rvsum, rvabs in H.
- replace (Lim_seq (sum_n (fun n : nat => Rabs (f n x))))
+ replace (Lim_seq (@sum_n R_AbelianGroup (fun n : nat => Rabs (f n x))))
with
(Lim_seq
(fun n : nat =>
- sum_n (fun n0 : nat => Rabs (f n0 x)) n)).
+ @sum_n R_AbelianGroup (fun n0 : nat => Rabs (f n0 x)) n)).
now rewrite H.
apply Lim_seq_ext.
intros; trivial.
@@ -949,11 +949,11 @@ Section defs.
- rewrite <- H.
apply Rabs_lim_sum_le0.
unfold rvsum, rvabs in H.
- replace (Lim_seq (sum_n (fun n : nat => Rabs (f n x))))
+ replace (Lim_seq (@sum_n R_AbelianGroup (fun n : nat => Rabs (f n x))))
with
(Lim_seq
(fun n : nat =>
- sum_n (fun n0 : nat => Rabs (f n0 x)) n)).
+ @sum_n R_AbelianGroup (fun n0 : nat => Rabs (f n0 x)) n)).
now rewrite H.
apply Lim_seq_ext.
intros; trivial.
diff --git a/coq/ProbTheory/Gaussian.v b/rocq/ProbTheory/Gaussian.v
similarity index 100%
rename from coq/ProbTheory/Gaussian.v
rename to rocq/ProbTheory/Gaussian.v
diff --git a/coq/ProbTheory/Independence.v b/rocq/ProbTheory/Independence.v
similarity index 100%
rename from coq/ProbTheory/Independence.v
rename to rocq/ProbTheory/Independence.v
diff --git a/coq/ProbTheory/Martingale.v b/rocq/ProbTheory/Martingale.v
similarity index 98%
rename from coq/ProbTheory/Martingale.v
rename to rocq/ProbTheory/Martingale.v
index ada6417a..4e7b17c7 100644
--- a/coq/ProbTheory/Martingale.v
+++ b/rocq/ProbTheory/Martingale.v
@@ -500,7 +500,7 @@ Section martingale.
FiniteExpectation prts (Y s) <= FiniteExpectation prts (Y t).
Proof.
intros s t sltt.
- destruct (le_lt_or_eq _ _ sltt).
+ destruct (proj1 (Nat.lt_eq_cases _ _) sltt).
- eapply is_sub_martingale_lt in mart; try eapply H.
assert (rv1:RandomVariable dom borel_sa (FiniteConditionalExpectation prts (sub s) (Y t))).
{
@@ -546,7 +546,7 @@ Section martingale.
cut (forall s t, (s <= t)%nat -> FiniteExpectation prts (Y s) = FiniteExpectation prts (Y t)).
{
intros.
- destruct (NPeano.Nat.le_ge_cases s t).
+ destruct (Nat.le_ge_cases s t).
- now apply H.
- symmetry; now apply H.
}
@@ -876,7 +876,7 @@ Section martingale.
end}
with
| None => right (fun x => x)
- | Some a => le_dec a n
+ | Some a => Compare_dec.le_dec a n
end).
@@ -931,7 +931,7 @@ Section martingale.
split; intros HH2.
* invcs HH2.
reflexivity.
- * apply le_n_0_eq in HH2; congruence.
+ * apply Nat.le_0_r in HH2; congruence.
+ generalize (HH (S n)); intros HH1.
generalize (HH n); intros HH2.
apply sa_complement in HH2.
@@ -954,7 +954,7 @@ Section martingale.
unfold stopping_time_pre_event, const.
apply sa_sigma_const.
destruct c.
- - destruct (le_dec n0 n); try tauto.
+ - destruct (Compare_dec.le_dec n0 n); try tauto.
- tauto.
Qed.
@@ -1341,14 +1341,14 @@ Section martingale.
intros ?.
unfold pre_event_union, pre_event_complement.
split; intros.
- - destruct (le_dec n0 n1); eauto.
+ - destruct (Compare_dec.le_dec n0 n1); eauto.
- destruct H; tauto.
}
rewrite eqq2; clear eqq2.
apply sa_union.
- apply sa_complement.
apply sa_sigma_const.
- destruct (le_dec n0 n1); tauto.
+ destruct (Compare_dec.le_dec n0 n1); tauto.
- apply stop'.
Qed.
@@ -1724,13 +1724,14 @@ Section martingale.
unfold rvsum.
destruct k; simpl.
- lra.
- - generalize (@Hierarchy.sum_n_plus Hierarchy.R_AbelianGroup
+ - generalize (@Hierarchy.sum_n_plus
+ (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup)
(fun n0 : nat => H1 (S n0) a * (X (S n0) a + -1 * X n0 a))
(fun n0 : nat => H2 (S n0) a * (X (S n0) a + -1 * X n0 a))
k); intros eqq.
unfold Hierarchy.plus in eqq; simpl in eqq.
rewrite <- eqq.
- apply (@Hierarchy.sum_n_ext Hierarchy.R_AbelianGroup); intros.
+ apply (@Hierarchy.sum_n_ext (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup)); intros.
lra.
Qed.
@@ -1741,7 +1742,7 @@ Section martingale.
unfold martingale_transform.
destruct k; trivial.
unfold rvsum; simpl.
- apply (@Hierarchy.sum_n_ext Hierarchy.R_AbelianGroup); intros.
+ apply (@Hierarchy.sum_n_ext Hierarchy.R_AbelianMonoid); intros.
rv_unfold.
now rewrite eqq1, eqq2.
Qed.
@@ -1797,7 +1798,7 @@ Section martingale.
* destruct HH as [??].
f_equal.
apply antisymmetry
- ; apply not_lt
+ ; apply Nat.le_ngt
; intros HH.
-- eapply classic_min_of_some_first in H; eauto.
-- specialize (H1 _ HH).
@@ -1811,7 +1812,7 @@ Section martingale.
- apply sa_inter.
+ apply adaptX.
+ apply sa_pre_countable_inter; intros.
- destruct (lt_dec n0 n).
+ destruct (Compare_dec.lt_dec n0 n).
* apply (sa_proper _ (fun x => ~ B (X n0 x))).
-- intros ?; tauto.
-- apply sa_complement.
@@ -1872,7 +1873,7 @@ Section martingale.
f_equal; lia.
+ apply sa_countable_union; intros.
unfold IsStoppingTime, stopping_time_pre_event in IHn.
- * destruct (lt_dec n0 a).
+ * destruct (Compare_dec.lt_dec n0 a).
-- apply sa_inter.
++ apply sa_sigma_const.
now left.
@@ -1926,7 +1927,7 @@ Section martingale.
+ split; [| congruence].
intros [?[?[??]]]; congruence.
- apply sa_countable_union; intros old.
- destruct (le_dec old n).
+ destruct (Compare_dec.le_dec old n).
+ apply sa_inter.
* eapply sa_proper; try eapply sa_all.
firstorder.
@@ -1972,7 +1973,7 @@ Section martingale.
shelve.
{
intros ?.
- rewrite plus_0_r.
+ rewrite Nat.add_0_r.
reflexivity.
}
{
@@ -1980,7 +1981,7 @@ Section martingale.
}
Unshelve.
match_destr.
- now rewrite plus_0_r.
+ now rewrite Nat.add_0_r.
Qed.
Lemma hitting_time_from_is_stop
@@ -1992,7 +1993,7 @@ Section martingale.
unfold hitting_time_from, hitting_time.
intros ?.
unfold stopping_time_pre_event.
- destruct (le_dec old n).
+ destruct (Compare_dec.le_dec old n).
- apply (sa_proper _ (fun x => B (X (n)%nat x) /\
forall k, (old <= k < n)%nat -> ~ B (X k x))).
{
@@ -2002,7 +2003,7 @@ Section martingale.
+ destruct HH as [??].
f_equal.
apply antisymmetry
- ; apply not_lt
+ ; apply Nat.le_ngt
; intros HH.
* apply (classic_min_of_some_first _ _ H (n-old)); [lia |].
now replace (n - old + old)%nat with n by lia.
@@ -2026,9 +2027,9 @@ Section martingale.
apply sa_inter.
+ apply adaptX.
+ apply sa_pre_countable_inter; intros.
- destruct (le_dec old n0).
+ destruct (Compare_dec.le_dec old n0).
{
- destruct (lt_dec n0 n).
+ destruct (Compare_dec.lt_dec n0 n).
- apply (sa_proper _ (fun x => ~ B (X n0 x))).
+ intros ?; tauto.
+ apply sa_complement.
diff --git a/coq/ProbTheory/MartingaleConvergence.v b/rocq/ProbTheory/MartingaleConvergence.v
similarity index 98%
rename from coq/ProbTheory/MartingaleConvergence.v
rename to rocq/ProbTheory/MartingaleConvergence.v
index 2fdcd2d9..51977e50 100644
--- a/coq/ProbTheory/MartingaleConvergence.v
+++ b/rocq/ProbTheory/MartingaleConvergence.v
@@ -80,7 +80,7 @@ Section mct.
Definition upcrossing_var_expr a b n ts k
:= match upcrossing_times a b (2*k) ts with
| None => 0%nat
- | Some upn => if le_dec upn n then k else 0%nat
+ | Some upn => if Compare_dec.le_dec upn n then k else 0%nat
end.
Definition upcrossing_var a b n (ts:Ts) : R
@@ -190,7 +190,7 @@ Section mct.
}
* apply sa_countable_union; intros.
{
- destruct (le_dec n0 m)%nat.
+ destruct (Compare_dec.le_dec n0 m)%nat.
- apply sa_inter.
+ generalize (upcrossing_times_is_stop a b (2 * n - 1) n0); unfold IsStoppingTime, stopping_time_pre_event.
eapply is_filtration_le; trivial.
@@ -209,7 +209,7 @@ Section mct.
split.
- match_destr.
intros HH.
- destruct (le_dec (S m) n0)%nat; trivial.
+ destruct (Compare_dec.le_dec (S m) n0)%nat; trivial.
elim HH.
eexists; split; [reflexivity |].
lia.
@@ -223,7 +223,7 @@ Section mct.
apply sa_complement.
apply sa_countable_union; intros.
{
- destruct (le_dec n0 m)%nat.
+ destruct (Compare_dec.le_dec n0 m)%nat.
- apply sa_inter.
+ generalize (upcrossing_times_is_stop a b (2 * n) n0); unfold IsStoppingTime, stopping_time_pre_event.
eapply is_filtration_le; trivial.
@@ -267,7 +267,7 @@ Section mct.
Proof.
replace (x + S x)%nat with (S (x + x))%nat by lia.
rewrite Nat.even_succ.
- rewrite <- NPeano.Nat.negb_even.
+ rewrite <- Nat.negb_even.
now rewrite plus_self_even.
Qed.
@@ -425,7 +425,7 @@ Section mct.
match_case; intros.
match_case_in IHh; intros.
+ rewrite H2 in IHh.
- eapply lt_trans.
+ eapply Nat.lt_trans.
apply IHh.
apply upcrossing_times_some with (n0 := n1) in H1; trivial; try lia.
+ apply upcrossing_times_none in H2; try lia.
@@ -466,7 +466,7 @@ Section mct.
Proof.
intros.
destruct (upcrossing_times_some_S a b (S k) a0 n1 H0); intros.
- destruct (lt_dec 0 k).
+ destruct (Compare_dec.lt_dec 0 k).
- generalize (upcrossing_times_some a b k a0 n0 x l H H1); intros.
generalize (upcrossing_times_some a b (S k) a0 x n1); intros.
cut_to H3; try lia; trivial.
@@ -630,7 +630,7 @@ Section mct.
- contrapose.
intros.
assert (m1 <= m0)%nat by lia.
- destruct (lt_dec m1 m0).
+ destruct (Compare_dec.lt_dec m1 m0).
+ generalize (upcrossing_times_monotonic_l a b a0 n1 n0 m1 m0 ); intros.
cut_to H5; trivial.
lia.
@@ -710,12 +710,12 @@ Section mct.
assert (2 * k + 1 < 2*S x)%nat by lia.
apply upcrossing_times_none_plus_alt with (kk := (2 * S x)%nat) in H1; try lia.
congruence.
- + destruct (lt_dec (2 * S x)%nat (2 * k)%nat).
+ + destruct (Compare_dec.lt_dec (2 * S x)%nat (2 * k)%nat).
* apply upcrossing_times_none_plus_alt with (kk := (2 * k)%nat) in H6; try lia.
congruence.
- * destruct (lt_dec (2 * k)%nat (2 * S x)%nat).
+ * destruct (Compare_dec.lt_dec (2 * k)%nat (2 * S x)%nat).
-- assert (2 * k + 1 <= 2 * S x - 1)%nat by lia.
- destruct (lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat).
+ destruct (Compare_dec.lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat).
++ apply upcrossing_times_none_plus_alt with (kk := (2 * S x - 1)%nat) in H1; try lia.
congruence.
++ assert ( 2 * k + 1 = 2 * S x - 1)%nat by lia.
@@ -794,8 +794,8 @@ Section mct.
(k2 < k)%nat.
Proof.
intros.
- destruct (le_dec k k2); try lia.
- destruct (lt_dec k k2).
+ destruct (Compare_dec.le_dec k k2); try lia.
+ destruct (Compare_dec.lt_dec k k2).
- now apply (upcrossing_times_none_plus_alt a b k k2 a0) in H; try lia.
- assert (k = k2) by lia.
now rewrite H1 in H.
@@ -829,18 +829,18 @@ Section mct.
- match_case_in H; intros; rewrite H5 in H; try easy.
match_case_in H; intros; rewrite H6 in H; try easy.
destruct H.
- destruct (lt_dec (S x) k).
+ destruct (Compare_dec.lt_dec (S x) k).
+ assert (n1 < n2)%nat.
{
specialize (H0 n1 n2 (2 * S x)%nat (2 * k)%nat).
cut_to H0; try lia; trivial.
}
lia.
- + destruct (lt_dec k (S x)).
+ + destruct (Compare_dec.lt_dec k (S x)).
* assert (2 * k + 1 <= 2 * S x - 1)%nat by lia.
assert (n3 <= n0)%nat.
{
- destruct (lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat).
+ destruct (Compare_dec.lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat).
- specialize (H0 n3 n0 (2 * k + 1)%nat (2 * S x - 1)%nat).
cut_to H0; try lia; trivial.
- assert (2 * k + 1 = 2 * S x - 1)%nat by lia.
@@ -868,7 +868,7 @@ Section mct.
match_case_in H; intros; rewrite H8 in H.
+ assert (n2 <= n0)%nat.
{
- destruct (lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat).
+ destruct (Compare_dec.lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat).
- specialize (H0 n2 n0 (2 * k + 1)%nat (2 * S x - 1)%nat).
cut_to H0; try lia; trivial.
- assert (2 * k + 1 = 2 * S x - 1)%nat by lia.
@@ -877,7 +877,7 @@ Section mct.
now invcs H3.
}
lia.
- + destruct (lt_dec (2 * k + 1)%nat (2 * S x - 1)).
+ + destruct (Compare_dec.lt_dec (2 * k + 1)%nat (2 * S x - 1)).
* now apply upcrossing_times_none_plus_alt with (kk := (2*S x - 1)%nat) in H8.
* assert (2 * k + 1 = 2 * S x - 1)%nat by lia.
rewrite H9 in H8.
@@ -921,14 +921,14 @@ Section mct.
generalize (upcrossing_times_0 a b a0 n1 n2); intros.
cut_to H8; trivial; try lia.
}
- destruct (lt_dec (S x) k).
+ destruct (Compare_dec.lt_dec (S x) k).
+ assert (n2 < n0)%nat.
{
specialize (H5 n2 n0 (2 * S x)%nat (2 * k)%nat).
cut_to H5; try lia; trivial.
}
lia.
- + destruct (lt_dec k (S x)).
+ + destruct (Compare_dec.lt_dec k (S x)).
* assert (2 * k + 1 < 2 * S x)%nat by lia.
apply (upcrossing_times_none_plus_alt a b (2 * k + 1)%nat (2 * S x)%nat a0) in H2; try lia.
congruence.
@@ -943,14 +943,14 @@ Section mct.
replace (2 * 0)%nat with 0%nat in H7 by lia.
congruence.
}
- destruct (lt_dec (S x) k).
+ destruct (Compare_dec.lt_dec (S x) k).
+ apply (upcrossing_times_none_plus_alt a b (2 * S x)%nat (2 * k)%nat a0) in H7; try lia.
congruence.
- + destruct (lt_dec k (S x)).
+ + destruct (Compare_dec.lt_dec k (S x)).
* assert (2 * k +1 <= 2 * S x - 1)%nat by lia.
assert (upcrossing_times a b (2 * S x - 1)%nat a0 = None).
{
- destruct (lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat).
+ destruct (Compare_dec.lt_dec (2 * k + 1)%nat (2 * S x - 1)%nat).
- now apply (upcrossing_times_none_plus_alt a b (2 * k + 1)%nat (2 * S x - 1)%nat a0) in H2; try lia.
- now replace (2 * k + 1)%nat with (2 * S x - 1)%nat in H2 by lia.
}
@@ -1564,7 +1564,7 @@ Section mct.
Definition upcrossing_var_expr1 a b n ts k
:= match upcrossing_times a b k ts with
| None => 0%nat
- | Some upn => if le_dec upn n then k else 0%nat
+ | Some upn => if Compare_dec.le_dec upn n then k else 0%nat
end.
Lemma upcrossing_bound_transform_ge_0 a b a0 k n0 n :
@@ -1598,7 +1598,7 @@ Section mct.
match_case_in H6; intros; rewrite H8 in H6; try easy.
rewrite H8 in H7.
replace (S n - 1)%nat with n in H7 by lia.
- destruct (lt_dec n2 (S n)).
+ destruct (Compare_dec.lt_dec n2 (S n)).
+ match_case_in H3; intros; rewrite H9 in H3; rewrite H9 in H7; apply H7; try lia.
match_destr_in H3; try lia.
+ assert (S n <= n2)%nat by lia.
@@ -1612,7 +1612,7 @@ Section mct.
forall j, (upcrossing_var_expr a b (S n) a0 j <= k)%nat.
Proof.
intros upk j.
- destruct (le_dec j (S n)).
+ destruct (Compare_dec.le_dec j (S n)).
- unfold upcrossing_var, upcrossing_var_expr in *.
match_option; [| lia].
match_destr; [| lia].
@@ -1669,7 +1669,7 @@ Section mct.
(upcrossing_var_expr a b (S n) a0 x1 <= kk)%nat).
{
intros.
- destruct (le_dec x1 (S n)).
+ destruct (Compare_dec.le_dec x1 (S n)).
- apply INR_le.
subst.
apply Hin'.
@@ -1716,7 +1716,7 @@ Section mct.
generalize (upcrossing_bound_range0_init a b a0); intros.
unfold Hierarchy.sum_n.
match_case_in H1; intros; rewrite H3 in H1.
- + destruct (lt_dec n n0).
+ + destruct (Compare_dec.lt_dec n n0).
* rewrite (@Hierarchy.sum_n_m_ext_loc Hierarchy.R_AbelianGroup) with
(b := fun n1 => Hierarchy.zero).
-- rewrite Hierarchy.sum_n_m_const_zero.
@@ -1835,7 +1835,7 @@ Section mct.
unfold Hierarchy.zero; simpl; lra.
- generalize (upcrossing_bound_transform_helper a b a0 n0); intros.
match_case_in H1; intros; rewrite H3 in H1.
- + destruct (lt_dec (n1-1)%nat n).
+ + destruct (Compare_dec.lt_dec (n1-1)%nat n).
* unfold Hierarchy.sum_n.
rewrite Hierarchy.sum_n_m_Chasles with (m := (n1-1)%nat); try lia.
unfold Hierarchy.sum_n in H1.
@@ -1892,7 +1892,7 @@ Section mct.
rewrite Hierarchy.sum_n_m_zero; try lia.
unfold Hierarchy.zero; simpl.
lra.
- - destruct (le_dec 1 (upcrossing_var_expr a b (S n) a0 (S k))).
+ - destruct (Compare_dec.le_dec 1 (upcrossing_var_expr a b (S n) a0 (S k))).
+ transitivity
(@Hierarchy.sum_n_m Hierarchy.R_AbelianGroup
(fun _ => b - a)
@@ -2097,7 +2097,7 @@ Section mct.
generalize (classic_min_of_some_first _ _ eqq); simpl; unfold id; intros HHY2.
generalize (classic_min_of_some_first _ _ eqq0); simpl; unfold id; intros HHM2.
apply antisymmetry
- ; apply not_lt
+ ; apply Nat.le_ngt
; intros HH.
+ apply (HHY2 _ HH).
unfold Y, ϕ.
@@ -2131,7 +2131,7 @@ Section mct.
generalize (classic_min_of_some_first _ _ eqq); simpl; unfold id; intros HHY2.
generalize (classic_min_of_some_first _ _ eqq0); simpl; unfold id; intros HHM2.
apply antisymmetry
- ; apply not_lt
+ ; apply Nat.le_ngt
; intros HH.
+ apply (HHY2 _ HH).
unfold Y, ϕ.
@@ -2797,7 +2797,7 @@ Section mct.
destruct (Qs_between_Rbars _ _ r) as [a [b [age [ab blt]]]].
specialize (H0 a b ab).
destruct (is_finite_witness _ H0) as [nmax eqq].
- elimtype False.
+ exfalso.
unfold Rbar_rvlim in eqq.
generalize (Elim_seq_incr_elem
diff --git a/coq/ProbTheory/MartingaleStopped.v b/rocq/ProbTheory/MartingaleStopped.v
similarity index 98%
rename from coq/ProbTheory/MartingaleStopped.v
rename to rocq/ProbTheory/MartingaleStopped.v
index af8e7810..869babac 100644
--- a/coq/ProbTheory/MartingaleStopped.v
+++ b/rocq/ProbTheory/MartingaleStopped.v
@@ -69,10 +69,10 @@ Section stopped_process.
unfold lift1_min.
rv_unfold; unfold rvsum.
destruct n; match_option.
- - destruct (Min.min_dec (S n) n0).
+ - destruct (Nat.min_dec (S n) n0).
+ assert (nle: (S n <= n0)%nat) by lia.
rewrite e.
- rewrite (@Hierarchy.sum_n_ext_loc Hierarchy.R_AbelianGroup _ (fun _ => 0)).
+ rewrite (@Hierarchy.sum_n_ext_loc (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) _ (fun _ => 0)).
* rewrite sum_n_zero.
field_simplify.
match_destr; try lra.
@@ -93,7 +93,7 @@ Section stopped_process.
rewrite eqq in p.
assert (n0 = S n) by lia.
subst.
- rewrite (@Hierarchy.sum_n_ext_loc Hierarchy.R_AbelianGroup _ (fun _ => 0)).
+ rewrite (@Hierarchy.sum_n_ext_loc (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) _ (fun _ => 0)).
-- rewrite sum_n_zero.
lra.
-- intros.
@@ -116,7 +116,7 @@ Section stopped_process.
elim n.
now red.
- rewrite Hierarchy.sum_Sn.
- destruct (le_lt_dec n2 n).
+ destruct (Compare_dec.le_lt_dec n2 n).
+ specialize (IHn l).
rewrite <- IHn.
unfold Hierarchy.plus; simpl.
@@ -126,7 +126,7 @@ Section stopped_process.
lia.
+ assert (n2 = S n) by lia.
subst.
- rewrite (@Hierarchy.sum_n_ext_loc Hierarchy.R_AbelianGroup _ (fun _ => 0)).
+ rewrite (@Hierarchy.sum_n_ext_loc Hierarchy.R_AbelianMonoid _ (fun _ => 0)).
-- rewrite sum_n_zero.
unfold Hierarchy.plus; simpl.
match_destr; try lra.
@@ -137,7 +137,7 @@ Section stopped_process.
assert (S n = n0) by congruence.
lia.
}
- - rewrite (@Hierarchy.sum_n_ext Hierarchy.R_AbelianGroup _ (fun _ => 0)).
+ - rewrite (@Hierarchy.sum_n_ext Hierarchy.R_AbelianMonoid _ (fun _ => 0)).
+ rewrite sum_n_zero.
field_simplify.
match_destr; try lra.
@@ -689,7 +689,7 @@ Section stopped_process.
unfold Hierarchy.plus; simpl.
destruct (Nat.eq_dec n (S n0)).
* subst.
- assert (0 <= @Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n0 : nat => Rabs (Y n0 x)) n0).
+ assert (0 <= @Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n0 : nat => Rabs (Y n0 x)) n0).
{
apply sum_n_nneg; intros.
apply Rabs_pos.
@@ -1078,7 +1078,7 @@ Section stopped_process.
Qed.
Lemma Rabs_sum_n_triang f n :
- Rabs (@Hierarchy.sum_n Hierarchy.R_AbelianGroup f n) <= @Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun k => Rabs (f k)) n.
+ Rabs (@Hierarchy.sum_n Hierarchy.R_AbelianMonoid f n) <= @Hierarchy.sum_n Hierarchy.R_AbelianMonoid (fun k => Rabs (f k)) n.
Proof.
induction n.
- now repeat rewrite Hierarchy.sum_O.
@@ -1181,7 +1181,7 @@ Section stopped_process.
unfold rvsum.
simpl.
rewrite Rabs_sum_n_triang.
- transitivity (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun _ => K) n0).
+ transitivity (@Hierarchy.sum_n Hierarchy.R_AbelianMonoid (fun _ => K) n0).
{
apply sum_n_le_loc; intros.
replace (Y (S n1) x + -1 * Y n1 x)
diff --git a/coq/ProbTheory/Measures.v b/rocq/ProbTheory/Measures.v
similarity index 99%
rename from coq/ProbTheory/Measures.v
rename to rocq/ProbTheory/Measures.v
index 508d6f83..633777d8 100644
--- a/coq/ProbTheory/Measures.v
+++ b/rocq/ProbTheory/Measures.v
@@ -31,7 +31,7 @@ Section measure.
Class is_measure (μ:event σ -> Rbar)
:= mk_measure {
- measure_proper :> Proper (event_equiv ==> eq) μ
+ measure_proper ::> Proper (event_equiv ==> eq) μ
; measure_none : μ event_none = 0%R
; measure_nneg a : Rbar_le 0 (μ a)
; measure_countable_disjoint_union (B:nat->event σ) :
@@ -283,7 +283,7 @@ Section outer_measure.
Class is_outer_measure (μ:pre_event T -> Rbar)
:= mk_outer_measure {
- outer_measure_proper :> Proper (pre_event_equiv ==> eq) μ
+ outer_measure_proper ::> Proper (pre_event_equiv ==> eq) μ
; outer_measure_none : μ pre_event_none = 0%R
; outer_measure_nneg a : Rbar_le 0 (μ a)
; outer_measure_countable_subadditive (A:pre_event T) (B:nat->pre_event T) :
@@ -295,7 +295,7 @@ Section outer_measure.
:= mk_outer_measure_alt {
outer_measure_alt_none : μ pre_event_none = 0%R
; outer_measure_alt_nneg a : Rbar_le 0 (μ a)
- ; outer_measure_alt_monotone :> Proper (pre_event_sub ==> Rbar_le) μ
+ ; outer_measure_alt_monotone ::> Proper (pre_event_sub ==> Rbar_le) μ
; outer_measure_alt_countable_union (B:nat->pre_event T) :
Rbar_le (μ (pre_union_of_collection B)) (ELim_seq (fun i : nat => sum_Rbar_n (fun n : nat => μ (B n)) i))
}.
@@ -1597,7 +1597,7 @@ Section premeasure.
(* we could generalize events, but that is too much work for now :-) *)
Class is_premeasure (λ:alg_set Alg -> Rbar)
:= mk_premeasure {
- premeasure_proper :> Proper (alg_equiv ==> eq) λ
+ premeasure_proper ::> Proper (alg_equiv ==> eq) λ
; premeasure_none : λ alg_none = 0%R
; premeasure_nneg a : Rbar_le 0 (λ a)
; premeasure_countable_disjoint_union (B:nat->alg_set Alg) :
@@ -1714,7 +1714,7 @@ Section premeasure.
apply alg_make_collection_disjoint_in in e2.
destruct e1 as [H11 H12].
destruct e2 as [H21 H22].
- destruct (not_eq _ _ xyneq) as [xlt|ylt].
+ destruct (Compare_dec.not_eq _ _ xyneq) as [xlt|ylt].
- eapply H22; eauto.
- eapply H12; eauto.
Qed.
@@ -2497,7 +2497,7 @@ Section semi_premeasure.
Class is_semipremeasure (λ:salg_set SAlg -> Rbar)
:= mk_semipremeasure {
- semipremeasure_proper :> Proper (salg_equiv ==> eq) λ
+ semipremeasure_proper ::> Proper (salg_equiv ==> eq) λ
; semipremeasure_nneg a : Rbar_le 0 (λ a)
; semipremeasure_list_disjoint_union (B:list (salg_set SAlg)) :
diff --git a/coq/ProbTheory/OrthoProject.v b/rocq/ProbTheory/OrthoProject.v
similarity index 99%
rename from coq/ProbTheory/OrthoProject.v
rename to rocq/ProbTheory/OrthoProject.v
index 675da36d..907f2bf7 100644
--- a/coq/ProbTheory/OrthoProject.v
+++ b/rocq/ProbTheory/OrthoProject.v
@@ -371,9 +371,9 @@ Section ortho_project.
exists (max NP NQ); intros n nN.
split.
* apply HP.
- generalize (Max.le_max_l NP NQ); lia.
+ generalize (Nat.le_max_l NP NQ); lia.
* apply HQ.
- generalize (Max.le_max_r NP NQ); lia.
+ generalize (Nat.le_max_r NP NQ); lia.
+ intros P Q Himp [N HP].
exists N; intros n nN.
auto.
diff --git a/coq/ProbTheory/ProbSpace.v b/rocq/ProbTheory/ProbSpace.v
similarity index 99%
rename from coq/ProbTheory/ProbSpace.v
rename to rocq/ProbTheory/ProbSpace.v
index 28f830d5..c9aac906 100644
--- a/coq/ProbTheory/ProbSpace.v
+++ b/rocq/ProbTheory/ProbSpace.v
@@ -32,7 +32,7 @@ Definition sum_of_probs_equals {T:Type} {σ:SigmaAlgebra T}
Class ProbSpace {T:Type} (σ:SigmaAlgebra T) :=
{
ps_P : event σ -> R;
- ps_proper :> Proper (event_equiv ==> eq) ps_P ;
+ ps_proper ::> Proper (event_equiv ==> eq) ps_P ;
ps_countable_disjoint_union (collection: nat -> event σ) :
(* Assume: collection is a subset of Sigma and its elements are pairwise disjoint. *)
@@ -292,7 +292,7 @@ Qed.
Definition make_collection_disjoint {T:Type} {σ:SigmaAlgebra T} (coll:nat->event σ) : nat -> event σ
:= fun x => coll x \ (union_of_collection (fun y =>
- if lt_dec y x
+ if Compare_dec.lt_dec y x
then coll y
else ∅)).
@@ -328,13 +328,13 @@ Proof.
intros y ylt cy.
apply H2.
exists y.
- destruct (lt_dec y x); intuition.
+ destruct (Compare_dec.lt_dec y x); intuition.
- intros [ce fce].
unfold make_collection_disjoint.
split; trivial.
unfold union_of_collection.
intros [n Hn].
- destruct (lt_dec n x); trivial.
+ destruct (Compare_dec.lt_dec n x); trivial.
eapply fce; eauto.
Qed.
@@ -346,7 +346,7 @@ Proof.
apply make_collection_disjoint_in in e2.
destruct e1 as [H11 H12].
destruct e2 as [H21 H22].
- destruct (not_eq _ _ xyneq) as [xlt|ylt].
+ destruct (Compare_dec.not_eq _ _ xyneq) as [xlt|ylt].
- eapply H22; eauto.
- eapply H12; eauto.
Qed.
@@ -395,7 +395,7 @@ Section classic.
split; trivial.
unfold union_of_collection.
intros [nn Hnn].
- destruct (lt_dec nn m); [ | tauto].
+ destruct (Compare_dec.lt_dec nn m); [ | tauto].
specialize (H0 _ Hnn).
lia.
- apply make_collection_disjoint_in in Hn.
@@ -478,7 +478,7 @@ Section ascending.
replace m with (0%nat) by lia.
reflexivity.
- intros.
- apply le_lt_or_eq in H.
+ apply Nat.lt_eq_cases in H.
destruct H.
+ red in asc.
rewrite <- asc.
@@ -522,7 +522,7 @@ Section ascending.
* now apply make_collection_disjoint_sub.
+ red.
unfold make_collection_disjoint.
- destruct (classic (proj1_sig (union_of_collection (fun y : nat => if lt_dec y (S n) then En y else event_none)) a)).
+ destruct (classic (proj1_sig (union_of_collection (fun y : nat => if Compare_dec.lt_dec y (S n) then En y else event_none)) a)).
* destruct H as [x HH2].
match_destr_in HH2; [ | red in HH2; tauto].
left.
diff --git a/coq/ProbTheory/ProductSpace.v b/rocq/ProbTheory/ProductSpace.v
similarity index 99%
rename from coq/ProbTheory/ProductSpace.v
rename to rocq/ProbTheory/ProductSpace.v
index f89152c9..55aa6516 100644
--- a/coq/ProbTheory/ProductSpace.v
+++ b/rocq/ProbTheory/ProductSpace.v
@@ -2983,8 +2983,8 @@ Qed.
- apply fst_rv.
- generalize compose_rv; intros HH.
cut (
- RandomVariable (product_sa s (ivector_sa i)) (ivector_nth idx (lt_S_n idx n idx_lt) i)
- (ivector_nth idx (lt_S_n idx n idx_lt) ∘ snd)).
+ RandomVariable (product_sa s (ivector_sa i)) (ivector_nth idx (proj2 (Nat.succ_lt_mono idx n) idx_lt) i)
+ (ivector_nth idx (proj2 (Nat.succ_lt_mono idx n) idx_lt) ∘ snd)).
{
apply RandomVariable_proper; try reflexivity.
now intros [??].
@@ -3011,12 +3011,12 @@ Qed.
Proof.
intros.
destruct n; try lia.
- assert (RandomVariable (ivector_sa (ivector_const (S n) σ)) σ (fun x : ivector T (S n) => ivector_nth idx2 (lt_S_n idx2 n pf2) (ivector_tl x))).
+ assert (RandomVariable (ivector_sa (ivector_const (S n) σ)) σ (fun x : ivector T (S n) => ivector_nth idx2 (proj2 (Nat.succ_lt_mono idx2 n) pf2) (ivector_tl x))).
{
generalize (compose_rv (dom1 := (ivector_sa (ivector_const (S n) σ)))
(dom2 := (ivector_sa (ivector_const n σ)))
ivector_tl
- (fun x => ivector_nth idx2 (lt_S_n idx2 n pf2) x)); intros.
+ (fun x => ivector_nth idx2 (proj2 (Nat.succ_lt_mono idx2 n) pf2) x)); intros.
apply H; typeclasses eauto.
}
assert (independent_rvs (ivector_ps ivec_ps) σ σ ivector_hd
@@ -3025,7 +3025,7 @@ Qed.
generalize (independent_rv_compose
(ivector_ps ivec_ps) σ (ivector_sa (ivector_const n σ)) σ σ
ivector_hd ivector_tl
- (fun x => x) (fun x => ivector_nth idx2 (lt_S_n idx2 n pf2) x)
+ (fun x => x) (fun x => ivector_nth idx2 (proj2 (Nat.succ_lt_mono idx2 n) pf2) x)
); intros.
cut_to H0.
- revert H0.
@@ -3059,14 +3059,14 @@ Qed.
destruct idx2; [lia |].
destruct idx1.
+ apply (ivector_nth_independent_rvs_0 (n:=S n) (p,i) idx2).
- + generalize (IHn i idx1 idx2 (lt_S_n idx1 n pf1) (lt_S_n idx2 n pf2) (lt_S_n idx1 idx2 H)).
+ + generalize (IHn i idx1 idx2 (proj2 (Nat.succ_lt_mono idx1 n) pf1) (proj2 (Nat.succ_lt_mono idx2 n) pf2) (proj2 (Nat.succ_lt_mono idx1 idx2) H)).
unfold independent_rvs, independent_events; intros HH A B.
specialize (HH A B).
etransitivity; [| etransitivity; [apply HH |]].
* generalize (product_sa_product p (ivector_ps i)
Ω
- (rv_preimage (fun tl => ivector_nth idx1 (lt_S_n idx1 n pf1) tl) A
- ∩ rv_preimage (fun tl => ivector_nth idx2 (lt_S_n idx2 n pf2) tl) B)); intros HH2.
+ (rv_preimage (fun tl => ivector_nth idx1 (proj2 (Nat.succ_lt_mono idx1 n) pf1) tl) A
+ ∩ rv_preimage (fun tl => ivector_nth idx2 (proj2 (Nat.succ_lt_mono idx2 n) pf2) tl) B)); intros HH2.
rewrite ps_one, Rmult_1_l in HH2.
rewrite <- HH2.
apply ps_proper; intros [??]; simpl.
@@ -3074,14 +3074,14 @@ Qed.
* { f_equal.
- generalize (product_sa_product p (ivector_ps i)
Ω
- (rv_preimage (fun tl => ivector_nth idx1 (lt_S_n idx1 n pf1) tl) A)); intros HH2.
+ (rv_preimage (fun tl => ivector_nth idx1 (proj2 (Nat.succ_lt_mono idx1 n) pf1) tl) A)); intros HH2.
rewrite ps_one, Rmult_1_l in HH2.
rewrite <- HH2.
apply ps_proper; intros [??]; simpl.
unfold pre_Ω, event_preimage, pre_event_inter; tauto.
- generalize (product_sa_product p (ivector_ps i)
Ω
- (rv_preimage (fun tl => ivector_nth idx2 (lt_S_n idx2 n pf2) tl) B)); intros HH2.
+ (rv_preimage (fun tl => ivector_nth idx2 (proj2 (Nat.succ_lt_mono idx2 n) pf2) tl) B)); intros HH2.
rewrite ps_one, Rmult_1_l in HH2.
rewrite <- HH2.
apply ps_proper; intros [??]; simpl.
@@ -4419,7 +4419,7 @@ Section ps_sequence_product.
simpl.
unfold inf_cylinder_event, section_seq_event in e0.
unfold inf_cylinder_event in H3.
- pose (w := fun i => match lt_dec i N with
+ pose (w := fun i => match Compare_dec.lt_dec i N with
| left pf => ivector_nth i pf x4
| right _ => inh
end).
diff --git a/coq/ProbTheory/ProductSpaceDep.v b/rocq/ProbTheory/ProductSpaceDep.v
similarity index 100%
rename from coq/ProbTheory/ProductSpaceDep.v
rename to rocq/ProbTheory/ProductSpaceDep.v
diff --git a/coq/ProbTheory/RandomVariable.v b/rocq/ProbTheory/RandomVariable.v
similarity index 100%
rename from coq/ProbTheory/RandomVariable.v
rename to rocq/ProbTheory/RandomVariable.v
diff --git a/coq/ProbTheory/RandomVariableFinite.v b/rocq/ProbTheory/RandomVariableFinite.v
similarity index 99%
rename from coq/ProbTheory/RandomVariableFinite.v
rename to rocq/ProbTheory/RandomVariableFinite.v
index fc87871e..db66c975 100644
--- a/coq/ProbTheory/RandomVariableFinite.v
+++ b/rocq/ProbTheory/RandomVariableFinite.v
@@ -688,11 +688,11 @@ Section fe.
(Xn : nat -> Ts -> R)
(Xn_pos : forall n, NonnegativeFunction (Xn n))
(is_fin_lim :
- forall omega, is_finite (Lim_seq (sum_n (fun n => Xn n omega)))):
- NonnegativeFunction (fun omega => Lim_seq (sum_n (fun n => Xn n omega))).
+ forall omega, is_finite (Lim_seq (@sum_n R_AbelianGroup (fun n => Xn n omega)))):
+ NonnegativeFunction (fun omega => Lim_seq (@sum_n R_AbelianGroup (fun n => Xn n omega))).
Proof.
unfold NonnegativeFunction in *; intros.
- generalize (Lim_seq_pos (sum_n (fun n : nat => Xn n x))).
+ generalize (Lim_seq_pos (@sum_n R_AbelianGroup (fun n : nat => Xn n x))).
rewrite <- is_fin_lim; simpl.
intros; apply H.
intros.
@@ -726,7 +726,7 @@ Section fe.
(Xn_rv : forall n, RandomVariable dom borel_sa (Xn n))
(isfe : forall n, IsFiniteExpectation (Xn n)) :
forall (n:nat),
- sum_n (fun n0 : nat => FiniteExpectation (Xn n0)) n =
+ @sum_n R_AbelianGroup (fun n0 : nat => FiniteExpectation (Xn n0)) n =
FiniteExpectation (rvsum Xn n).
Proof.
intros.
@@ -883,8 +883,8 @@ Lemma Fatou_FiniteExpectation
apply H.
intros.
unfold rvsum, sum_n.
- replace (sum_n_m (fun n1 : nat => Xn n1 x) 0 n0) with
- (sum_n_m (fun n1 : nat => Xn n1 x) 0 n0 + 0) by lra.
+ replace (@sum_n_m R_AbelianGroup (fun n1 : nat => Xn n1 x) 0 n0) with
+ (@sum_n_m R_AbelianGroup (fun n1 : nat => Xn n1 x) 0 n0 + 0) by lra.
rewrite sum_n_Sm; [|lia].
unfold plus; simpl.
apply Rplus_le_compat_l.
diff --git a/coq/ProbTheory/RandomVariableL2.v b/rocq/ProbTheory/RandomVariableL2.v
similarity index 100%
rename from coq/ProbTheory/RandomVariableL2.v
rename to rocq/ProbTheory/RandomVariableL2.v
diff --git a/coq/ProbTheory/RandomVariableLinf.v b/rocq/ProbTheory/RandomVariableLinf.v
similarity index 99%
rename from coq/ProbTheory/RandomVariableLinf.v
rename to rocq/ProbTheory/RandomVariableLinf.v
index 8815c594..7b3ec206 100644
--- a/coq/ProbTheory/RandomVariableLinf.v
+++ b/rocq/ProbTheory/RandomVariableLinf.v
@@ -1537,14 +1537,20 @@ Qed.
LiRRVq_simpl.
apply LiRRV_plus_inv.
Qed.
-
- Definition LiRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of LiRRVq
- := AbelianGroup.Mixin LiRRVq LiRRVq_plus LiRRVq_opp LiRRVq_zero
+
+ Definition LiRRVq_AbelianMonoid_mixin : AbelianMonoid.mixin_of LiRRVq
+ := AbelianMonoid.Mixin LiRRVq LiRRVq_plus LiRRVq_zero
LiRRVq_plus_comm LiRRVq_plus_assoc
- LiRRVq_plus_zero LiRRVq_plus_inv.
+ LiRRVq_plus_zero.
+
+ Canonical LiRRVq_AbelianMonoid :=
+ AbelianMonoid.Pack LiRRVq LiRRVq_AbelianMonoid_mixin LiRRVq.
+
+ Definition LiRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of LiRRVq_AbelianMonoid
+ := AbelianGroup.Mixin LiRRVq_AbelianMonoid LiRRVq_opp LiRRVq_plus_inv.
Canonical LiRRVq_AbelianGroup :=
- AbelianGroup.Pack LiRRVq LiRRVq_AbelianGroup_mixin LiRRVq.
+ AbelianGroup.Pack LiRRVq (AbelianGroup.Class _ (LiRRVq_AbelianMonoid_mixin) LiRRVq_AbelianGroup_mixin) LiRRVq.
Ltac LiRRVq_simpl ::=
repeat match goal with
@@ -1590,7 +1596,7 @@ Qed.
LiRRVq_scale_plus_l LiRRVq_scale_plus_r.
Canonical LiRRVq_ModuleSpace :=
- ModuleSpace.Pack R_Ring LiRRVq (ModuleSpace.Class R_Ring LiRRVq LiRRVq_AbelianGroup_mixin LiRRVq_ModuleSpace_mixin) LiRRVq.
+ ModuleSpace.Pack R_Ring LiRRVq (ModuleSpace.Class R_Ring LiRRVq _ LiRRVq_ModuleSpace_mixin) LiRRVq.
Definition LiRRVq_norm : LiRRVq -> R
:= quot_rec LiRRV_norm_proper.
diff --git a/coq/ProbTheory/RandomVariableLpNat.v b/rocq/ProbTheory/RandomVariableLpNat.v
similarity index 98%
rename from coq/ProbTheory/RandomVariableLpNat.v
rename to rocq/ProbTheory/RandomVariableLpNat.v
index ccf6d690..dd97852e 100644
--- a/coq/ProbTheory/RandomVariableLpNat.v
+++ b/rocq/ProbTheory/RandomVariableLpNat.v
@@ -343,9 +343,9 @@ Section Lp.
* unfold Binomial.C.
left; unfold Rdiv.
apply Rmult_lt_0_compat.
- -- apply lt_0_INR; apply lt_O_fact.
+ -- apply lt_0_INR; apply Factorial.lt_O_fact.
-- apply Rinv_0_lt_compat.
- apply Rmult_lt_0_compat; apply lt_0_INR; apply lt_O_fact.
+ apply Rmult_lt_0_compat; apply lt_0_INR; apply Factorial.lt_O_fact.
* destruct (Rle_dec y x).
-- replace (p) with (i + (p-i))%nat at 2.
++ rewrite Rdef_pow_add.
@@ -573,8 +573,8 @@ Section Lp.
Record LpRRV : Type
:= LpRRV_of {
LpRRV_rv_X :> Ts -> R
- ; LpRRV_rv :> RandomVariable dom borel_sa LpRRV_rv_X
- ; LpRRV_lp :> IsLp p LpRRV_rv_X
+ ; LpRRV_rv ::> RandomVariable dom borel_sa LpRRV_rv_X
+ ; LpRRV_lp ::> IsLp p LpRRV_rv_X
}.
Global Existing Instance LpRRV_rv.
@@ -869,13 +869,19 @@ Section Lp.
apply LpRRV_plus_inv.
Qed.
- Definition LpRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of LpRRVq
- := AbelianGroup.Mixin LpRRVq LpRRVq_plus LpRRVq_opp LpRRVq_zero
+ Definition LpRRVq_AbelianMonoid_mixin : AbelianMonoid.mixin_of LpRRVq
+ := AbelianMonoid.Mixin LpRRVq LpRRVq_plus LpRRVq_zero
LpRRVq_plus_comm LpRRVq_plus_assoc
- LpRRVq_plus_zero LpRRVq_plus_inv.
+ LpRRVq_plus_zero.
+
+ Canonical LpRRVq_AbelianMonoid :=
+ AbelianMonoid.Pack LpRRVq LpRRVq_AbelianMonoid_mixin LpRRVq.
+
+ Definition LpRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of LpRRVq_AbelianMonoid
+ := AbelianGroup.Mixin LpRRVq_AbelianMonoid LpRRVq_opp LpRRVq_plus_inv.
Canonical LpRRVq_AbelianGroup :=
- AbelianGroup.Pack LpRRVq LpRRVq_AbelianGroup_mixin LpRRVq.
+ AbelianGroup.Pack LpRRVq (AbelianGroup.Class _ (LpRRVq_AbelianMonoid_mixin) LpRRVq_AbelianGroup_mixin) LpRRVq.
Ltac LpRRVq_simpl ::=
@@ -922,7 +928,7 @@ Section Lp.
LpRRVq_scale_plus_l LpRRVq_scale_plus_r.
Canonical LpRRVq_ModuleSpace :=
- ModuleSpace.Pack R_Ring LpRRVq (ModuleSpace.Class R_Ring LpRRVq LpRRVq_AbelianGroup_mixin LpRRVq_ModuleSpace_mixin) LpRRVq.
+ ModuleSpace.Pack R_Ring LpRRVq (ModuleSpace.Class R_Ring LpRRVq _ LpRRVq_ModuleSpace_mixin) LpRRVq.
End quot.
@@ -1014,7 +1020,7 @@ Section Lp.
Proof.
repeat rewrite Rsqr_pow2.
repeat rewrite <- pow_mult.
- now rewrite mult_comm.
+ now rewrite Nat.mul_comm.
Qed.
Lemma pow_incr_inv (x y:R) (n : nat) :
diff --git a/coq/ProbTheory/RandomVariableLpR.v b/rocq/ProbTheory/RandomVariableLpR.v
similarity index 99%
rename from coq/ProbTheory/RandomVariableLpR.v
rename to rocq/ProbTheory/RandomVariableLpR.v
index 7824b88b..d30038d5 100644
--- a/coq/ProbTheory/RandomVariableLpR.v
+++ b/rocq/ProbTheory/RandomVariableLpR.v
@@ -469,8 +469,8 @@ Qed.
Record LpRRV : Type
:= LpRRV_of {
LpRRV_rv_X :> Ts -> R
- ; LpRRV_rv :> RandomVariable dom borel_sa LpRRV_rv_X
- ; LpRRV_lp :> IsLp p LpRRV_rv_X
+ ; LpRRV_rv ::> RandomVariable dom borel_sa LpRRV_rv_X
+ ; LpRRV_lp ::> IsLp p LpRRV_rv_X
}.
Global Existing Instance LpRRV_rv.
@@ -886,13 +886,19 @@ Qed.
apply LpRRV_plus_inv.
Qed.
- Definition LpRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of (LpRRVq p)
- := AbelianGroup.Mixin (LpRRVq p) LpRRVq_plus LpRRVq_opp LpRRVq_zero
+ Definition LpRRVq_AbelianMonoid_mixin : AbelianMonoid.mixin_of (LpRRVq p)
+ := AbelianMonoid.Mixin (LpRRVq p) LpRRVq_plus LpRRVq_zero
LpRRVq_plus_comm LpRRVq_plus_assoc
- LpRRVq_plus_zero LpRRVq_plus_inv.
+ LpRRVq_plus_zero.
+
+ Canonical LpRRVq_AbelianMonoid :=
+ AbelianMonoid.Pack (LpRRVq p) LpRRVq_AbelianMonoid_mixin (LpRRVq p).
+
+ Definition LpRRVq_AbelianGroup_mixin : AbelianGroup.mixin_of LpRRVq_AbelianMonoid
+ := AbelianGroup.Mixin LpRRVq_AbelianMonoid LpRRVq_opp LpRRVq_plus_inv.
Canonical LpRRVq_AbelianGroup :=
- AbelianGroup.Pack (LpRRVq p) LpRRVq_AbelianGroup_mixin (LpRRVq p).
+ AbelianGroup.Pack (LpRRVq p) (AbelianGroup.Class _ (LpRRVq_AbelianMonoid_mixin) LpRRVq_AbelianGroup_mixin) (LpRRVq p).
Ltac LpRRVq_simpl ::=
repeat match goal with
@@ -938,7 +944,7 @@ Qed.
LpRRVq_scale_plus_l LpRRVq_scale_plus_r.
Canonical LpRRVq_ModuleSpace :=
- ModuleSpace.Pack R_Ring (LpRRVq p) (ModuleSpace.Class R_Ring (LpRRVq p) LpRRVq_AbelianGroup_mixin LpRRVq_ModuleSpace_mixin) (LpRRVq p).
+ ModuleSpace.Pack R_Ring (LpRRVq p) (ModuleSpace.Class R_Ring (LpRRVq p) _ LpRRVq_ModuleSpace_mixin) (LpRRVq p).
End quotnneg.
End packednonneg.
@@ -1409,7 +1415,7 @@ Qed.
sum_n_m (fun k => pow c k) (S n) m = (pow c (S m) - pow c (S n))/(c-1).
Proof.
intros.
- rewrite sum_n_m_sum_n; [|lia].
+ rewrite (@sum_n_m_sum_n R_AbelianGroup); [|lia].
rewrite sum_geom; trivial.
rewrite sum_geom; trivial.
unfold minus, plus, opp; simpl.
diff --git a/coq/ProbTheory/RbarExpectation.v b/rocq/ProbTheory/RbarExpectation.v
similarity index 99%
rename from coq/ProbTheory/RbarExpectation.v
rename to rocq/ProbTheory/RbarExpectation.v
index 43ad83df..4c02d671 100644
--- a/coq/ProbTheory/RbarExpectation.v
+++ b/rocq/ProbTheory/RbarExpectation.v
@@ -3673,8 +3673,8 @@ Theorem Dominated_convergence
erewrite (Rbar_NonnegExpectation_ext _ _ H7) in le2.
erewrite (Rbar_NonnegExpectation_ext _ _ H8) in le3.
- rewrite (Rbar_FiniteExpectation_Rbar_NonnegExpectation _) in le2.
- rewrite (Rbar_FiniteExpectation_Rbar_NonnegExpectation _) in le3.
+ erewrite (Rbar_FiniteExpectation_Rbar_NonnegExpectation _) in le2.
+ erewrite (Rbar_FiniteExpectation_Rbar_NonnegExpectation _) in le3.
rewrite (ELimInf_proper _ (fun n => (Rbar_FiniteExpectation g) +
(Rbar_FiniteExpectation (fn n)))) in le2.
@@ -3772,6 +3772,10 @@ Theorem Dominated_convergence
-- apply is_Elim_seq_const.
-- apply lim1.
-- destruct (f x); reflexivity.
+ + eauto.
+ + eauto.
+ + eauto.
+ + eauto.
Qed.
Theorem Dominated_convergence_almost
@@ -4543,7 +4547,7 @@ Section rv_expressible.
rewrite <- H4.
apply is_lim_seq_spec in H1.
destruct (H1 M).
- destruct (le_dec x1 n).
+ destruct (Compare_dec.le_dec x1 n).
* now apply H5.
* assert (n <= x1)%nat by lia.
apply Rle_lt_trans with (r2 := simple_approx Y x1 x).
diff --git a/coq/ProbTheory/RealRandomVariable.v b/rocq/ProbTheory/RealRandomVariable.v
similarity index 100%
rename from coq/ProbTheory/RealRandomVariable.v
rename to rocq/ProbTheory/RealRandomVariable.v
diff --git a/coq/ProbTheory/RealVectorHilbert.v b/rocq/ProbTheory/RealVectorHilbert.v
similarity index 98%
rename from coq/ProbTheory/RealVectorHilbert.v
rename to rocq/ProbTheory/RealVectorHilbert.v
index 9891d148..63c23d53 100644
--- a/coq/ProbTheory/RealVectorHilbert.v
+++ b/rocq/ProbTheory/RealVectorHilbert.v
@@ -333,13 +333,19 @@ Section Rvector_defs.
apply Rvector_inv_plus.
Qed.
- Definition Rvector_AbelianGroup_mixin : AbelianGroup.mixin_of (vector R n)
- := AbelianGroup.Mixin (vector R n) Rvector_plus Rvector_opp Rvector_zero
+ Definition Rvector_AbelianMonoid_mixin : AbelianMonoid.mixin_of (vector R n)
+ := AbelianMonoid.Mixin (vector R n) Rvector_plus Rvector_zero
Rvector_plus_comm Rvector_plus_assoc
- Rvector_plus_zero Rvector_plus_inv.
+ Rvector_plus_zero.
+
+ Canonical Rvector_AbelianMonoid :=
+ AbelianMonoid.Pack (vector R n) Rvector_AbelianMonoid_mixin (vector R n).
+
+ Definition Rvector_AbelianGroup_mixin : AbelianGroup.mixin_of (Rvector_AbelianMonoid)
+ := AbelianGroup.Mixin Rvector_AbelianMonoid Rvector_opp Rvector_plus_inv.
Canonical Rvector_AbelianGroup :=
- AbelianGroup.Pack (vector R n) Rvector_AbelianGroup_mixin (vector R n).
+ AbelianGroup.Pack (vector R n) (AbelianGroup.Class _ (Rvector_AbelianMonoid_mixin) Rvector_AbelianGroup_mixin) (vector R n).
Lemma Rvector_scale_scale (a b:R) (v:vector R n) :
a .* (b .* v) = (a * b) .* v.
@@ -390,7 +396,7 @@ Section Rvector_defs.
Rvector_scale_plus_l Rvector_scale_plus_r.
Canonical Rvector_ModuleSpace :=
- ModuleSpace.Pack R_Ring (vector R n) (ModuleSpace.Class R_Ring (vector R n) Rvector_AbelianGroup_mixin Rvector_ModuleSpace_mixin) (vector R n).
+ ModuleSpace.Pack R_Ring (vector R n) (ModuleSpace.Class R_Ring (vector R n) _ Rvector_ModuleSpace_mixin) (vector R n).
Lemma Rvector_scale_inj (c:R) (x y:vector R n) :
c <> 0%R -> c .* x = c .* y -> x = y.
@@ -950,7 +956,7 @@ Section Rvector_defs.
forall m (pf2:(m <= n)%nat),
F (fun v : vector R n =>
forall (i : nat) (pf : (i < m)%nat),
- P i (lt_le_trans _ _ _ pf pf2) (vector_nth i (lt_le_trans _ _ _ pf pf2) v)).
+ P i (Nat.lt_le_trans _ _ _ pf pf2) (vector_nth i (Nat.lt_le_trans _ _ _ pf pf2) v)).
Proof.
intros [???] FA.
induction m; simpl; intros mle.
@@ -964,8 +970,8 @@ Section Rvector_defs.
forall (i : nat) (pf : (i < m)%nat),
P i (Nat.lt_le_trans i (S m) n (pft _ pf) mle) (vector_nth i (Nat.lt_le_trans i (S m) n (pft _ pf) mle) v))).
+ apply filter_imp; intros.
- generalize (lt_n_Sm_le _ _ pf); intros pf2.
- apply le_lt_or_eq in pf2.
+ generalize (proj1 (Nat.lt_succ_r _ _) pf); intros pf2.
+ apply Nat.lt_eq_cases in pf2.
destruct H.
destruct pf2.
* generalize (H0 _ H1).
@@ -992,7 +998,7 @@ Section Rvector_defs.
P i pf (vector_nth i pf v)).
Proof.
intros.
- generalize (Filter_Forall_commute_aux _ _ H H0 n (le_refl n)).
+ generalize (Filter_Forall_commute_aux _ _ H H0 n (Nat.le_refl n)).
destruct H.
apply filter_imp; intros.
specialize (H i pf).
@@ -1064,7 +1070,7 @@ Section Rvector_defs.
destruct H as [?[???]].
generalize filter_true.
apply filter_imp; intros.
- rewrite (vector_zero0 e (Rvector_lim (fun x : vector R n -> Prop => F x))).
+ rewrite (vector_zero0 e (Rvector_lim F)).
rewrite (vector_zero0 e x).
apply ball_center.
- apply Rvector_lim_complete_pos.
diff --git a/coq/ProbTheory/SigmaAlgebras.v b/rocq/ProbTheory/SigmaAlgebras.v
similarity index 98%
rename from coq/ProbTheory/SigmaAlgebras.v
rename to rocq/ProbTheory/SigmaAlgebras.v
index d1fe763e..75c87ab1 100644
--- a/coq/ProbTheory/SigmaAlgebras.v
+++ b/rocq/ProbTheory/SigmaAlgebras.v
@@ -7,7 +7,7 @@ Require Import Morphisms EquivDec Program.
Require Import Utils DVector.
Require Export Event.
-Require Import Lia.
+Require Import Arith Lia.
Set Bullet Behavior "Strict Subproofs".
@@ -1635,13 +1635,6 @@ Qed.
reflexivity.
Qed.
- Lemma lt_S_n_S i1 n pf :
- (Lt.lt_S_n i1 n (Lt.lt_n_S i1 n pf)) = pf.
- Proof.
- apply digit_pf_irrel.
- Qed.
-
-
Lemma generated_rectangle_proj {T} {n} (s : SigmaAlgebra T) (i : ivector (SigmaAlgebra T) n) (e : pre_event (ivector T n)) :
sa_sigma (generated_sa (pre_event_set_ivector_product (ivector_map sa_sigma i))) e ->
sa_sigma (generated_sa (pre_event_set_ivector_product (ivector_map sa_sigma (n:=S n) (s, i)))) (fun '(_, x₂) => e x₂).
@@ -1668,8 +1661,11 @@ Qed.
apply HH.
-- apply H0; intros.
destruct x0.
- specialize (HH (S i0) (Lt.lt_n_S _ _ pf)); simpl in *.
- now rewrite lt_S_n_S in HH.
+ assert (pf':S i0 < S n) by lia.
+ specialize (HH (S i0) pf'); simpl in *.
+ erewrite (ivector_nth_prf_irrelevance _ x).
+ erewrite (ivector_nth_prf_irrelevance _ i1).
+ apply HH.
Qed.
Lemma ivector_rectangles_generate_sa {n} {T}
@@ -1761,7 +1757,7 @@ Qed.
destruct x0.
exists p.
exists (fun v => ivector_Forall2 (fun a0 (x : T) => a0 x) i0 v).
- generalize (H1 0 (NPeano.Nat.lt_0_succ n)); intros.
+ generalize (H1 0 (Nat.lt_0_succ n)); intros.
simpl in H3.
split; trivial.
split.
@@ -1770,23 +1766,27 @@ Qed.
unfold pre_event_set_ivector_product.
exists i0; split.
-- intros.
- specialize (H1 (S i1) (Lt.lt_n_S i1 n pf)).
+ specialize (H1 (S i1) (proj1 (Nat.succ_lt_mono i1 n) pf)).
simpl in H1.
- now rewrite lt_S_n_S in H1.
+ erewrite (ivector_nth_prf_irrelevance _ (ivector_map sa_sigma i)).
+ erewrite (ivector_nth_prf_irrelevance _ i0).
+ apply H1.
-- intros ?.
now rewrite <- ivector_Forall2_nth_iff.
* rewrite H2.
intros ?.
destruct x0.
split; intros.
- generalize (H4 0 (NPeano.Nat.lt_0_succ n)); intros.
+ generalize (H4 0 (Nat.lt_0_succ n)); intros.
simpl in H5.
split; trivial.
-- rewrite <- ivector_Forall2_nth_iff.
intros.
- specialize (H4 (S i2) (Lt.lt_n_S i2 n pf)).
+ specialize (H4 (S i2) ((proj1 (Nat.succ_lt_mono i2 n) pf))).
simpl in H4.
- now rewrite lt_S_n_S in H4.
+ erewrite (ivector_nth_prf_irrelevance _ i0).
+ erewrite (ivector_nth_prf_irrelevance _ i1).
+ apply H4.
-- destruct H4.
rewrite <- ivector_Forall2_nth_iff in H5.
destruct i2.
diff --git a/coq/ProbTheory/SimpleExpectation.v b/rocq/ProbTheory/SimpleExpectation.v
similarity index 99%
rename from coq/ProbTheory/SimpleExpectation.v
rename to rocq/ProbTheory/SimpleExpectation.v
index 635c8eb6..6cecb39a 100644
--- a/coq/ProbTheory/SimpleExpectation.v
+++ b/rocq/ProbTheory/SimpleExpectation.v
@@ -2443,8 +2443,8 @@ Section SimpleConditionalExpectation.
unfold pre_list_collection in H2.
assert (n'bound:(n' < length (map event_pre l))%nat).
{
- destruct (lt_dec n' (length (map event_pre l))%nat); trivial.
- apply not_lt in n0.
+ destruct (Compare_dec.lt_dec n' (length (map event_pre l))%nat); trivial.
+ apply Nat.le_ngt in n0.
unfold pre_event in H2.
rewrite (nth_overflow (map event_pre l) pre_event_none n0) in H2.
elim (Cn'ne H2).
diff --git a/coq/ProbTheory/VectorConditionalExpectation.v b/rocq/ProbTheory/VectorConditionalExpectation.v
similarity index 100%
rename from coq/ProbTheory/VectorConditionalExpectation.v
rename to rocq/ProbTheory/VectorConditionalExpectation.v
diff --git a/coq/ProbTheory/VectorRandomVariable.v b/rocq/ProbTheory/VectorRandomVariable.v
similarity index 99%
rename from coq/ProbTheory/VectorRandomVariable.v
rename to rocq/ProbTheory/VectorRandomVariable.v
index beff87bb..e27896a3 100644
--- a/coq/ProbTheory/VectorRandomVariable.v
+++ b/rocq/ProbTheory/VectorRandomVariable.v
@@ -354,7 +354,7 @@ Section vector_ops.
Qed.
Lemma vecrvsum_rvsum {n} (f : Ts -> vector R n) :
- rv_eq (vecrvsum f) (rvsum (fun i x => match lt_dec i n with
+ rv_eq (vecrvsum f) (rvsum (fun i x => match Compare_dec.lt_dec i n with
| left pf => vector_nth i pf (f x)
| right _ => 0%R
end)
@@ -365,8 +365,8 @@ Section vector_ops.
destruct (f a); simpl.
subst.
rewrite list_sum_sum_n.
- apply (@Hierarchy.sum_n_ext Hierarchy.R_AbelianGroup); intros.
- destruct (lt_dec n (length x)).
+ apply (@Hierarchy.sum_n_ext Hierarchy.R_AbelianMonoid); intros.
+ destruct (Compare_dec.lt_dec n (length x)).
- unfold vector_nth.
match goal with
[|- context [proj1_sig ?x]] => destruct x
@@ -1800,7 +1800,7 @@ Section real_pullback.
rewrite <- vector_Forall2_nth_iff.
intros.
rewrite vector_nth_const.
- destruct (lt_dec i n).
+ destruct (Compare_dec.lt_dec i n).
- rewrite vector_nth_add_to_end_prefix with (pf2 := l).
now rewrite vector_nth_const.
- assert (i = n)%nat by lia.
@@ -1911,7 +1911,7 @@ Section real_pullback.
rewrite <- vector_Forall2_nth_iff.
intros.
rewrite vector_nth_const.
- destruct (lt_dec i n).
+ destruct (Compare_dec.lt_dec i n).
- rewrite vector_nth_add_to_end_prefix with (pf2 := l).
now rewrite vector_nth_const.
- assert (i = n)%nat by lia.
@@ -2012,7 +2012,7 @@ Section real_pullback.
exists (vector_add_to_end x0 (vector_const pre_Ω n)).
split; intros.
-- rewrite vector_nth_map, vector_nth_const.
- destruct (lt_dec i n).
+ destruct (Compare_dec.lt_dec i n).
++ rewrite vector_nth_add_to_end_prefix with (pf2 := l).
rewrite vector_nth_const.
apply sa_all.
@@ -2021,7 +2021,7 @@ Section real_pullback.
now rewrite vector_nth_add_to_end_suffix.
-- intro z.
split; intros.
- ++ destruct (lt_dec i n).
+ ++ destruct (Compare_dec.lt_dec i n).
** rewrite vector_nth_add_to_end_prefix with (pf2 := l).
rewrite vector_nth_const.
apply I.
@@ -2145,7 +2145,7 @@ Section almost.
apply all_almost; intros ??.
now apply vector_Forall2_nth_iff.
+ intros.
- apply lt_dec.
+ apply Compare_dec.lt_dec.
+ unfold vecrvnth.
intros.
now repeat rewrite (vector_nth_ext _ _ pf2 pf1).
diff --git a/coq/ProbTheory/lintp_wrapper.v b/rocq/ProbTheory/lintp_wrapper.v
similarity index 100%
rename from coq/ProbTheory/lintp_wrapper.v
rename to rocq/ProbTheory/lintp_wrapper.v
diff --git a/coq/QLearn/Bellman.v b/rocq/QLearn/Bellman.v
similarity index 100%
rename from coq/QLearn/Bellman.v
rename to rocq/QLearn/Bellman.v
diff --git a/coq/QLearn/Dvoretzky.v b/rocq/QLearn/Dvoretzky.v
similarity index 98%
rename from coq/QLearn/Dvoretzky.v
rename to rocq/QLearn/Dvoretzky.v
index ef72e92d..9a43cbcf 100644
--- a/coq/QLearn/Dvoretzky.v
+++ b/rocq/QLearn/Dvoretzky.v
@@ -372,7 +372,7 @@ Lemma Dvoretzky_rel (n:nat) (theta:R) (X Y : nat -> Ts -> R)
part_prod_n a n m <= part_prod_n b n m.
Proof.
intros.
- destruct (le_dec n m).
+ destruct (Compare_dec.le_dec n m).
- pose (h := (m - n)%nat).
replace (m) with (n + h)%nat by lia.
apply part_prod_n_le_h.
@@ -604,7 +604,7 @@ Qed.
(sum_n_m (fun n0 : nat => -2 * a n0) (S n0) (n + S n0))).
{
intros.
- rewrite <- sum_split; trivial; lia.
+ rewrite <- (@sum_split R_AbelianGroup); trivial; lia.
}
unfold plus in H13; simpl in H13.
exists (-2*x1 - sum_n_m (fun n0 : nat => -2 * a n0) 0 n0 ).
@@ -994,15 +994,15 @@ Section Derman_Sacks.
- rewrite is_lim_seq_incr_n with (N := N).
apply is_lim_seq_ext with
(u := fun n =>
- (sum_n_m (fun j : nat => (delta j - c j) / B j) 0 (n + N)) -
- (sum_n_m (fun j : nat => (delta j - c j) / B j) 0 (N-1))).
+ (@sum_n_m R_AbelianGroup (fun j : nat => (delta j - c j) / B j) 0 (n + N)) -
+ (@sum_n_m R_AbelianGroup (fun j : nat => (delta j - c j) / B j) 0 (N-1))).
intros.
- rewrite sum_split with (m := (N-1)%nat); try lia.
+ rewrite (@sum_split R_AbelianGroup) with (m := (N-1)%nat); try lia.
unfold plus; simpl; ring_simplify.
replace (S (N-1))%nat with N by lia; trivial.
apply is_lim_seq_minus with
(l1 := m_infty)
- (l2 := sum_n_m (fun j : nat => (delta j - c j) / B j) 0 (N - 1)).
+ (l2 := @sum_n_m R_AbelianGroup (fun j : nat => (delta j - c j) / B j) 0 (N - 1)).
+ now rewrite is_lim_seq_incr_n with (N := N) in H13.
+ apply is_lim_seq_const.
+ now unfold is_Rbar_minus, is_Rbar_plus; simpl.
@@ -2829,7 +2829,7 @@ Theorem Dvoretzky_DS_scale_prop
(m n i:nat) pf1 pf2 ts :
vector_nth i pf1 (DS_Xn_v X0 T Y m ts) = vector_nth i pf2 (DS_Xn_v X0 T Y n ts).
Proof.
- destruct (le_dec n m).
+ destruct (Compare_dec.le_dec n m).
- now apply DS_Xn_v_same_prefix_le_helper.
- symmetry.
apply DS_Xn_v_same_prefix_le_helper.
@@ -3776,15 +3776,14 @@ Lemma Dvoretzky_converge_W_alpha_beta_uniform (W w α β: nat -> Ts -> R)
almostR2 prts Rbar_le
(ConditionalExpectation prts (filt_sub n) (rvsqr (w n)))
(const (Rsqr C)))) ->
- almost prts (fun ω : Ts => is_lim_seq (sum_n(fun n : nat => α n ω)) p_infty) ->
- almost prts (fun ω : Ts => is_lim_seq (sum_n (fun n : nat => β n ω)) p_infty) ->
+ almost prts (fun ω : Ts => is_lim_seq (@sum_n R_AbelianGroup (fun n : nat => α n ω)) p_infty) ->
+ almost prts (fun ω : Ts => is_lim_seq (@sum_n R_AbelianGroup (fun n : nat => β n ω)) p_infty) ->
almost prts (fun ω => ex_series (fun n => Rsqr (β n ω))) ->
- (exists epsilon : posreal, eventually (fun n => almostR2 prts Rbar_lt (fun ω => Lim_seq (sum_n (fun nn => rvsqr (β (nn+n)%nat) ω))) (const epsilon))) ->
+ (exists epsilon : posreal, eventually (fun n => almostR2 prts Rbar_lt (fun ω => Lim_seq (@sum_n R_AbelianGroup (fun nn => rvsqr (β (nn+n)%nat) ω))) (const epsilon))) ->
(forall n, rv_eq (W (S n)) (rvplus (rvmult (rvminus (const 1) (α n)) (W n)) (rvmult (w n) (β n)))) ->
almost _ (fun ω => is_lim_seq (fun n => W n ω) (Finite 0)).
Proof.
intros condexpw condexpw2 alpha_inf beta_inf beta_sqr [ϵ beta_bounded] (* W0 *) Wrel.
-
assert (svy1b: forall n : nat, IsFiniteExpectation prts (rvsqr (β n))).
{
intros.
@@ -3792,7 +3791,10 @@ Proof.
}
eapply (@Dvoretzky_converge_W_alpha_beta_isf_seq_sum W w α β F isfilt filt_sub adaptZ adapt_alpha adapt_beta rvw); eauto.
-
+ change (is_finite
+ (Lim_seq
+ (@sum_n R_AbelianGroup
+ (fun n : nat => @FiniteExpectation Ts dom prts (@rvsqr Ts (β n)) (svy1b n))))).
generalize (sum_expectation prts (fun n => rvsqr (β n))); intros HH.
assert (rv2 : forall n, RandomVariable dom borel_sa (rvsqr (β n))).
{
@@ -3812,7 +3814,7 @@ Proof.
unfold A3'.
unfold Series.
apply Rbar_real_rv.
- cut (RandomVariable dom Rbar_borel_sa (fun omega : Ts => ELim_seq (sum_n (fun n : nat => (β n omega)²)))).
+ cut (RandomVariable dom Rbar_borel_sa (fun omega : Ts => ELim_seq (@sum_n R_AbelianGroup (fun n : nat => (β n omega)²)))).
{
apply RandomVariable_proper; try reflexivity; intros ?.
now rewrite <- Elim_seq_fin.
@@ -3878,7 +3880,7 @@ Proof.
specialize (betaN (S N)).
cut_to betaN; try lia.
- pose (btail := (rvplus (fun ω => sum_n (fun nn : nat => rvsqr (β nn) ω) N)
+ pose (btail := (rvplus (fun ω => @sum_n R_AbelianGroup (fun nn : nat => rvsqr (β nn) ω) N)
(const (pos ϵ)))).
assert (btail_rv : RandomVariable dom Rbar_borel_sa btail).
@@ -3892,7 +3894,7 @@ Proof.
apply IsFiniteExpectation_Rbar.
apply Rbar_finexp_finexp.
{
- cut (RandomVariable dom Rbar_borel_sa (fun omega : Ts => ELim_seq (sum_n (fun n : nat => (β n omega)²)))).
+ cut (RandomVariable dom Rbar_borel_sa (fun omega : Ts => ELim_seq (@sum_n R_AbelianGroup (fun n : nat => (β n omega)²)))).
{
apply RandomVariable_proper; try reflexivity; intros ?.
now rewrite <- Elim_seq_fin.
@@ -3901,7 +3903,7 @@ Proof.
}
apply (Rbar_IsFiniteExpectation_nnf_bounded_almost _ _ btail).
- intros ?.
- generalize (Lim_seq_le (fun _ => 0) (sum_n (fun n : nat => (β n x)²)))
+ generalize (Lim_seq_le (fun _ => 0) (@sum_n R_AbelianGroup (fun n : nat => (β n x)²)))
; intros HHH.
cut_to HHH.
+ rewrite Lim_seq_const in HHH.
@@ -3916,13 +3918,13 @@ Proof.
apply almost_impl.
apply all_almost; intros ω bsqr_ex bsqr_bound.
rewrite <- (Lim_seq_incr_n _ (S N)).
- assert (eqq:Lim_seq (fun n : nat => sum_n (fun n0 : nat => (β n0 ω)²) (n + S N)) =
+ assert (eqq:Lim_seq (fun n : nat => @sum_n R_AbelianGroup (fun n0 : nat => (β n0 ω)²) (n + S N)) =
- (Lim_seq (fun n => sum_n (fun nn : nat => rvsqr (β nn) ω) N +
- (sum_n (fun nn : nat => rvsqr (β (nn + S N)%nat) ω) n)))).
+ (Lim_seq (fun n => @sum_n R_AbelianGroup (fun nn : nat => rvsqr (β nn) ω) N +
+ (@sum_n R_AbelianGroup (fun nn : nat => rvsqr (β (nn + S N)%nat) ω) n)))).
{
apply Lim_seq_ext; intros n.
- generalize (sum_split (fun n0 : nat => (β n0 ω)²) 0 ((n + S N)%nat) N)
+ generalize (@sum_split R_AbelianGroup (fun n0 : nat => (β n0 ω)²) 0 ((n + S N)%nat) N)
; intros HHH.
cut_to HHH; try lia.
unfold sum_n.
@@ -3931,12 +3933,19 @@ Proof.
f_equal.
now rewrite sum_shift.
}
+ change ( Rbar_le (Lim_seq (fun n : nat => @sum_n R_AbelianGroup (fun n0 : nat => (β n0 ω)²) (n + S N)))
+ (rvplus (fun ω0 : Ts => @sum_n R_AbelianGroup (fun nn : nat => rvsqr (β nn) ω0) N) (const ϵ) ω)).
rewrite eqq.
rewrite Lim_seq_plus.
+ unfold rvplus.
rewrite Lim_seq_const.
- replace (Finite (sum_n (fun nn : nat => rvsqr (β nn) ω) N + const (pos ϵ) ω)) with
- (Rbar_plus (sum_n (fun nn : nat => rvsqr (β nn) ω) N) (const (pos ϵ) ω)) by reflexivity.
+ replace (Finite
+ (Rplus
+ (@sum_n (AbelianGroup.AbelianMonoid R_AbelianGroup)
+ (fun nn : nat => @rvsqr Ts (β nn) ω) N)
+ (pos (@const posreal Ts ϵ ω))))
+ with
+ (Rbar_plus (@sum_n R_AbelianGroup (fun nn : nat => rvsqr (β nn) ω) N) (const (pos ϵ) ω)) by reflexivity.
apply Rbar_plus_le_compat.
{ apply Rbar_le_refl. }
now apply Rbar_lt_le.
@@ -4139,16 +4148,18 @@ Proof.
+ destruct n; [lia |].
rewrite <- (Lim_seq_incr_n (sum_n (fun n0 : nat => rvsqr (β n0) ω)) (S n)).
apply Lim_seq_le; intros.
- generalize (sum_split (fun n0 : nat => (β n0 ω)²) 0 ((n0 + S n)%nat) n)
+ generalize (@sum_split R_AbelianGroup (fun n0 : nat => (β n0 ω)²) 0 ((n0 + S n)%nat) n)
; intros HHH.
- cut_to HHH; try lia.
- unfold sum_n, rvsqr.
- rewrite HHH.
- unfold plus, rvsqr; simpl.
- rewrite sum_shift.
- cut (0 <= sum_n_m (fun n1 : nat => (β n1 ω)²) 0 n); try lra.
- apply sum_n_m_pos; intros.
- apply Rle_0_sqr.
+ cut_to HHH; try lia.
+ unfold sum_n, rvsqr.
+ etransitivity; [| right; symmetry; apply HHH].
+ unfold plus, rvsqr; simpl.
+ rewrite sum_shift.
+ change (@sum_n_m R_AbelianGroup (fun nn : nat => (β (nn + S n)%nat ω)²) 0 n0 <=
+ @sum_n_m R_AbelianGroup (fun n1 : nat => (β n1 ω)²) 0 n + @sum_n_m R_AbelianGroup (fun n1 : nat => (β (n1 + S n)%nat ω)²) 0 n0).
+ cut (0 <= @sum_n_m R_AbelianGroup (fun n1 : nat => (β n1 ω)²) 0 n); try lra.
+ apply sum_n_m_pos; intros.
+ apply Rle_0_sqr.
+ simpl.
apply Rmax_r.
Qed.
diff --git a/coq/QLearn/Tsitsiklis.v b/rocq/QLearn/Tsitsiklis.v
similarity index 99%
rename from coq/QLearn/Tsitsiklis.v
rename to rocq/QLearn/Tsitsiklis.v
index 3a235ba3..0cc2eca5 100644
--- a/coq/QLearn/Tsitsiklis.v
+++ b/rocq/QLearn/Tsitsiklis.v
@@ -202,7 +202,7 @@ Proof.
}
pose (tau_coll k t j :=
- match (le_dec j t) with
+ match (Compare_dec.le_dec j t) with
| left pf => event_lt (rv := rvB j t pf) (F t) (B j) (INR k)
| _ => Ω
end).
@@ -516,7 +516,7 @@ Proof.
}
pose (tau_coll k t j :=
- match (le_dec j t) with
+ match (Compare_dec.le_dec j t) with
| left pf => event_lt (rv := rvB j t pf) (F t) (B j) (INR k)
| _ => Ω
end).
@@ -3991,7 +3991,7 @@ Qed.
{
apply almost_bounded_forall.
intros.
- - apply lt_dec.
+ - apply Compare_dec.lt_dec.
- intros.
apply is_lim_seq_ext with (u := (fun k : nat => WW i pf1 k 0%nat x)); trivial.
intros.
@@ -4128,7 +4128,7 @@ Qed.
destruct (H21 H22).
exists ((1 + ε) * (G x0 x)).
intros.
- destruct (le_dec t x0).
+ destruct (Compare_dec.le_dec t x0).
- apply Rle_trans with (r2 := M x0 x).
+ unfold M.
apply Rmax_seq_map_monotone.
@@ -5001,7 +5001,7 @@ Qed.
{
apply almost_bounded_forall.
intros.
- - apply lt_dec.
+ - apply Compare_dec.lt_dec.
- intros.
apply is_lim_seq_ext with (u := (fun k : nat => WW i pf1 k 0%nat x)); trivial.
intros.
@@ -5145,7 +5145,7 @@ Qed.
destruct (H21 H22).
exists ((1 + ε) * (G x0 x)).
intros.
- destruct (le_dec t x0).
+ destruct (Compare_dec.le_dec t x0).
- apply Rle_trans with (r2 := M x0 x).
+ unfold M.
apply Rmax_seq_map_monotone.
@@ -6022,7 +6022,7 @@ Qed.
{
apply almost_bounded_forall.
intros.
- - apply lt_dec.
+ - apply Compare_dec.lt_dec.
- intros.
apply is_lim_seq_ext with (u := (fun k : nat => WW i pf1 k 0%nat x)); trivial.
intros.
@@ -6166,7 +6166,7 @@ Qed.
destruct (H21 H22).
exists ((1 + ε) * (G x0 x)).
intros.
- destruct (le_dec t x0).
+ destruct (Compare_dec.le_dec t x0).
- apply Rle_trans with (r2 := M x0 x).
+ unfold M.
apply Rmax_seq_map_monotone.
@@ -6767,7 +6767,7 @@ Qed.
{
apply almost_bounded_forall.
intros.
- - apply le_dec.
+ - apply Compare_dec.le_dec.
- intros.
rewrite (digit_pf_irrel _ _ pf2 pf1).
apply H14.
@@ -7378,7 +7378,7 @@ Qed.
{
apply almost_bounded_forall.
intros.
- - apply le_dec.
+ - apply Compare_dec.le_dec.
- intros.
rewrite (digit_pf_irrel _ _ pf2 pf1).
apply H14.
@@ -7717,7 +7717,7 @@ Qed.
apply H11.
+ apply Rle_ge, Rvector_max_abs_nonneg.
- intros.
- apply lt_dec.
+ apply Compare_dec.lt_dec.
- intros i pf1 pf2 x.
apply is_lim_seq_ext.
intros.
@@ -8001,7 +8001,7 @@ Qed.
apply H11.
+ apply Rle_ge, Rvector_max_abs_nonneg.
- intros.
- apply lt_dec.
+ apply Compare_dec.lt_dec.
- intros i pf1 pf2 x.
apply is_lim_seq_ext.
intros.
@@ -8972,7 +8972,7 @@ Qed.
* apply H14.
lia.
+ intros.
- apply le_dec.
+ apply Compare_dec.le_dec.
+ intros.
apply H14.
- revert alpha_betaprop; apply almost_impl.
@@ -9038,7 +9038,7 @@ Qed.
* apply H0.
lia.
+ intros.
- apply le_dec.
+ apply Compare_dec.le_dec.
+ intros.
apply H14.
- revert H1; apply almost_impl.
diff --git a/coq/QLearn/infprod.v b/rocq/QLearn/infprod.v
similarity index 97%
rename from coq/QLearn/infprod.v
rename to rocq/QLearn/infprod.v
index 30431b92..6b82ec96 100644
--- a/coq/QLearn/infprod.v
+++ b/rocq/QLearn/infprod.v
@@ -1,4 +1,4 @@
-Require Import Reals Sums Lra Lia.
+Require Import ZArith Reals Sums Lra Lia.
(* Require Import Coquelicot.Hierarchy Coquelicot.Series Coquelicot.Lim_seq Coquelicot.Rbar.*)
Require Import Coquelicot.Coquelicot.
Require Import LibUtils.
@@ -493,7 +493,7 @@ Proof.
replace (S n2 - n1)%nat with ((S m - n1) + (S n2 - S m))%nat by lia.
rewrite seq_plus.
rewrite List.fold_right_app.
- rewrite fold_right_plus_acc.
+ rewrite (@fold_right_plus_acc G).
now replace (n1 + (S m - n1))%nat with (S m) by lia.
Qed.
@@ -534,7 +534,7 @@ Qed.
now simpl.
+ intros.
unfold sum_n.
- rewrite sum_split with (m := (nk-1)%nat); try lia.
+ rewrite (@sum_split R_AbelianGroup) with (m := (nk-1)%nat); try lia.
apply Rplus_eq_compat_l.
replace (S (nk - 1)) with (nk) by lia.
apply sum_n_m_shift.
@@ -562,7 +562,8 @@ Qed.
cut (ex_lim_seq (fun n : nat => sum_n_m α 0 (nk + S n) - sum_n_m α 0 nk)).
{
apply ex_lim_seq_ext; intros.
- rewrite (sum_split_plus α 0 nk (S n)); try lia.
+ change ((@sum_n_m R_AbelianGroup α 0 (nk + S n) - @sum_n_m R_AbelianGroup α 0 nk) = @sum_n_m R_AbelianGroup α (S nk) (n + S nk)).
+ rewrite (@sum_split_plus R_AbelianGroup α 0 nk (S n) ltac:(lia) ltac:(lia)).
unfold plus; simpl.
field_simplify.
f_equal.
@@ -885,7 +886,7 @@ Proof.
destruct H0 as [k H0]; destruct H0.
exists k.
split; trivial; intros.
- destruct (lt_dec m n).
+ destruct (Compare_dec.lt_dec m n).
+ remember (n - S m)%nat as nm.
replace (n) with (S m + nm)%nat; [|lia].
rewrite initial_seg_prod_n; trivial.
@@ -948,7 +949,7 @@ Proof.
specialize (IHk H1).
apply Rle_trans with (r2 := part_prod_n (pos_sq_fun F) (m + k) n); trivial.
replace (m + S k)%nat with (S (m+k)%nat) by lia.
- destruct (le_gt_dec (S (m+k)) n).
+ destruct (Compare_dec.le_gt_dec (S (m+k)) n).
+ apply max_bounded1_pre_le; trivial.
intros; apply pos_sq_bounded1; trivial.
+ rewrite (part_prod_n_1 (pos_sq_fun F) (S (m + k)%nat)) ; [|lia].
@@ -1110,7 +1111,7 @@ Section Dvoretsky.
Theorem Dvoretzky4_0 (F: nat -> posreal) (sigma V : nat -> R) :
(forall (n:nat), V (S n) <= (F n) * (V n) + (sigma n)) ->
(forall (n:nat),
- V (S n) <= sum_n (fun k => (sigma k)*(part_prod_n F (S k) n)) n +
+ V (S n) <= @sum_n R_AbelianGroup (fun k => (sigma k)*(part_prod_n F (S k) n)) n +
(V 0%nat)*(part_prod_n F 0 n)).
Proof.
intros.
@@ -1145,8 +1146,8 @@ Qed.
Lemma sum_bound_prod_A (F : nat -> posreal) (sigma : nat -> R) (A : R) (n m:nat) :
(forall r s, part_prod_n (pos_sq_fun F) r s <= A) ->
- sum_n_m (fun k => (Rsqr (sigma k))*(part_prod_n (pos_sq_fun F) (S k) n)) (S m) n <=
- (sum_n_m (fun k => Rsqr (sigma k)) (S m) n) * A.
+ @sum_n_m R_AbelianGroup (fun k => (Rsqr (sigma k))*(part_prod_n (pos_sq_fun F) (S k) n)) (S m) n <=
+ (@sum_n_m R_AbelianGroup (fun k => Rsqr (sigma k)) (S m) n) * A.
Proof.
intros.
rewrite <- sum_n_m_mult_r with (a := A).
@@ -1160,8 +1161,8 @@ Qed.
Lemma sum_bound3_max (F : nat -> posreal) (sigma : nat -> R) (n m:nat) :
(S m <= n)%nat ->
- sum_n (fun k => (Rsqr (sigma k))*(part_prod_n (pos_sq_fun F) (S k) n)) m <=
- (sum_n (fun k => (Rsqr (sigma k))) m) * (max_prod_fun (pos_sq_fun F) (S m) n).
+ @sum_n R_AbelianGroup (fun k => (Rsqr (sigma k))*(part_prod_n (pos_sq_fun F) (S k) n)) m <=
+ (@sum_n R_AbelianGroup (fun k => (Rsqr (sigma k))) m) * (max_prod_fun (pos_sq_fun F) (S m) n).
Proof.
intros.
rewrite <- sum_n_mult_r with (a := (max_prod_fun (pos_sq_fun F) (S m) n)).
@@ -1176,8 +1177,8 @@ Theorem Dvoretzky4_8_5 (F : nat -> posreal) (sigma V: nat -> R) (n m:nat) (A:R):
(forall (n:nat), Rsqr (V (S n)) <= (pos_sq_fun F) n * Rsqr (V n) + Rsqr (sigma n)) ->
(m
Rsqr (V (S n)) <=
- ( sum_n_m (fun k => Rsqr (sigma k)) (S m) n) * A +
- (Rsqr (V 0%nat) + sum_n (fun k => (Rsqr (sigma k))) m) *
+ ( @sum_n_m R_AbelianGroup (fun k => Rsqr (sigma k)) (S m) n) * A +
+ (Rsqr (V 0%nat) + @sum_n R_AbelianGroup (fun k => (Rsqr (sigma k))) m) *
(max_prod_fun (pos_sq_fun F) (S m) n).
Proof.
intros F1 Vsqle mn.
@@ -1185,6 +1186,7 @@ Proof.
intros.
specialize (H Vsqle n).
unfold sum_n in H.
+
rewrite (sum_split _ _ _ m) in H; trivial; [|lia].
generalize (sum_bound_prod_A F sigma A n m F1); intros.
generalize (max_prod_le (pos_sq_fun F) 0 (S m) n); intros.
@@ -1204,8 +1206,8 @@ Lemma sum_bound_prod_A_sigma1
(F : nat -> posreal) (sigma : nat -> R) (A : R) (n m:nat) :
(forall r s, part_prod_n (pos_sq_fun F) r s <= A) ->
(forall n, 0 <= sigma n) ->
- sum_n_m (fun k => (sigma k)*(part_prod_n (pos_sq_fun F) (S k) n)) (S m) n <=
- (sum_n_m sigma (S m) n) * A.
+ @sum_n_m R_AbelianGroup (fun k => (sigma k)*(part_prod_n (pos_sq_fun F) (S k) n)) (S m) n <=
+ (@sum_n_m R_AbelianGroup sigma (S m) n) * A.
Proof.
intros.
rewrite <- sum_n_m_mult_r with (a := A).
@@ -1218,8 +1220,8 @@ Qed.
Lemma sum_bound3_max_sigma1 (F : nat -> posreal) (sigma : nat -> R) (n m:nat) :
(S m <= n)%nat ->
(forall n, 0 <= sigma n) ->
- sum_n (fun k => (sigma k)*(part_prod_n (pos_sq_fun F) (S k) n)) m <=
- (sum_n sigma m) * (max_prod_fun (pos_sq_fun F) (S m) n).
+ @sum_n R_AbelianGroup (fun k => (sigma k)*(part_prod_n (pos_sq_fun F) (S k) n)) m <=
+ (@sum_n R_AbelianGroup sigma m) * (max_prod_fun (pos_sq_fun F) (S m) n).
Proof.
intros.
rewrite <- sum_n_mult_r with (a := (max_prod_fun (pos_sq_fun F) (S m) n)).
@@ -1238,8 +1240,8 @@ Theorem Dvoretzky4_8_5_V1 (F : nat -> posreal) (sigma V: nat -> R) (n m:nat) (A:
(forall (n:nat), 0 <= sigma n) ->
(m
V (S n) <=
- (sum_n_m sigma (S m) n) * A +
- (V 0%nat + sum_n sigma m) *
+ (@sum_n_m R_AbelianGroup sigma (S m) n) * A +
+ (V 0%nat + @sum_n R_AbelianGroup sigma m) *
(max_prod_fun (pos_sq_fun F) (S m) n).
Proof.
intros F1 Vle Vpos sigma_pos mn.
@@ -1269,12 +1271,12 @@ Theorem Dvoretzky4_8_5_1 (F : nat -> posreal) (sigma V: nat -> R) (n m:nat) (A s
is_series (fun n => Rsqr (sigma n)) sigmasum ->
(m
Rsqr (V (S n)) <=
- (sum_n_m (fun k => Rsqr (sigma k)) (S m) n) * A +
+ (@sum_n_m R_AbelianGroup (fun k => Rsqr (sigma k)) (S m) n) * A +
(Rsqr (V 0%nat) + sigmasum) * (max_prod_fun (pos_sq_fun F) (S m) n).
Proof.
intros.
generalize (Dvoretzky4_8_5 F sigma V n m A H H0 H2); intros.
- assert (sum_n (fun k : nat => (sigma k)²) m <= sigmasum).
+ assert (@sum_n R_AbelianGroup (fun k : nat => (sigma k)²) m <= sigmasum).
- assert (H1' := H1).
apply is_series_unique in H1.
assert (ex_series (fun k : nat => (sigma k)²)).
@@ -1300,12 +1302,12 @@ Theorem Dvoretzky4_8_5_1_V1 (F : nat -> posreal) (sigma V: nat -> R) (n m:nat) (
is_series sigma sigmasum ->
(m
V (S n) <=
- (sum_n_m sigma (S m) n) * A +
+ (@sum_n_m R_AbelianGroup sigma (S m) n) * A +
(V 0%nat + sigmasum) * (max_prod_fun (pos_sq_fun F) (S m) n).
Proof.
intros.
generalize (Dvoretzky4_8_5_V1 F sigma V n m A H H0 H2 H1 H4); intros.
- assert (sum_n sigma m <= sigmasum).
+ assert (@sum_n R_AbelianGroup sigma m <= sigmasum).
- assert (H3' := H3).
apply is_series_unique in H3.
assert (ex_series sigma).
diff --git a/coq/QLearn/jaakkola_vector.v b/rocq/QLearn/jaakkola_vector.v
similarity index 99%
rename from coq/QLearn/jaakkola_vector.v
rename to rocq/QLearn/jaakkola_vector.v
index 44c3d33f..13faa82a 100644
--- a/coq/QLearn/jaakkola_vector.v
+++ b/rocq/QLearn/jaakkola_vector.v
@@ -2879,7 +2879,7 @@ Section jaakola_vector2.
apply all_almost; intros ??.
now apply lim_seq_maxabs0_b.
- intros.
- apply lt_dec.
+ apply Compare_dec.lt_dec.
- intros.
revert H0.
apply is_lim_seq_ext.
@@ -3030,16 +3030,16 @@ Section jaakola_vector2.
repeat rewrite vector_map_create.
repeat rewrite vector_nth_create.
repeat rewrite vector_nth_map.
- assert (pf3 : (i < (S N))%nat).
+ assert (pf3 : (0 + i < 0 + (S N))%nat).
{
lia.
}
simpl.
repeat match goal with
- [|- context [@vector_nth ?RR ?nn i (plus_lt_compat_l ?pi ?pn ?pc ?pp) ?v]] =>
- replace (@vector_nth RR nn i (plus_lt_compat_l pi pn pc pp) v)
- with (@vector_nth RR nn i pf3 v)
- by apply vector_nth_ext
+ [|- context [@vector_nth R (S N) i (@proj1 ?h1 ?h2 (Nat.add_lt_mono_l ?pi ?pn ?pc) ?pp) ?v]] =>
+ replace (@vector_nth R (S N) i (@proj1 h1 h2 (Nat.add_lt_mono_l pi pn pc) pp) v)
+ with (@vector_nth R (S N) i pf3 v)
+ by now apply vector_nth_ext
end.
lra.
}
@@ -3274,7 +3274,7 @@ Section jaakola_vector2.
apply H18.
}
apply almost_bounded_forall; intros.
- - apply lt_dec.
+ - apply Compare_dec.lt_dec.
- revert H19.
apply is_lim_seq_ext.
intros.
@@ -4458,7 +4458,7 @@ Section jaakola_vector2.
apply bounded_forall_almost with (n := i) (pf := pf) in H0.
- apply H0.
- intros.
- apply lt_dec.
+ apply Compare_dec.lt_dec.
- intros.
erewrite vector_nth_ext; try apply H11.
}
@@ -4471,7 +4471,7 @@ Section jaakola_vector2.
apply eventually_impl.
apply all_eventually; intros.
apply almost_bounded_forall.
- + intros; apply lt_dec.
+ + intros; apply Compare_dec.lt_dec.
+ intros.
erewrite vector_nth_ext; try apply H12.
+ apply H11.
@@ -4486,7 +4486,7 @@ Section jaakola_vector2.
apply bounded_forall_almost with (n := i) (pf := pf) in H1.
- apply H1.
- intros.
- apply lt_dec.
+ apply Compare_dec.lt_dec.
- intros.
erewrite vector_nth_ext; try apply H11.
}
@@ -4499,7 +4499,7 @@ Section jaakola_vector2.
apply eventually_impl.
apply all_eventually; intros.
apply almost_bounded_forall.
- + intros; apply lt_dec.
+ + intros; apply Compare_dec.lt_dec.
+ intros.
erewrite vector_nth_ext; try apply H12.
+ apply H11.
@@ -5192,7 +5192,7 @@ Section jaakola_vector2.
apply bounded_forall_almost with (n := i) (pf := pf) in H0.
- apply H0.
- intros.
- apply lt_dec.
+ apply Compare_dec.lt_dec.
- intros.
erewrite vector_nth_ext; try apply H10.
}
@@ -5205,7 +5205,7 @@ Section jaakola_vector2.
apply eventually_impl.
apply all_eventually; intros.
apply almost_bounded_forall.
- + intros; apply lt_dec.
+ + intros; apply Compare_dec.lt_dec.
+ intros.
erewrite vector_nth_ext; try apply H11.
+ apply H10.
@@ -5220,7 +5220,7 @@ Section jaakola_vector2.
apply bounded_forall_almost with (n := i) (pf := pf) in H1.
- apply H1.
- intros.
- apply lt_dec.
+ apply Compare_dec.lt_dec.
- intros.
erewrite vector_nth_ext; try apply H10.
}
@@ -5233,7 +5233,7 @@ Section jaakola_vector2.
apply eventually_impl.
apply all_eventually; intros.
apply almost_bounded_forall.
- + intros; apply lt_dec.
+ + intros; apply Compare_dec.lt_dec.
+ intros.
erewrite vector_nth_ext; try apply H11.
+ apply H10.
@@ -5395,7 +5395,7 @@ Section jaakola_vector2.
unfold rvscale.
now rewrite Rmult_comm.
- intros.
- apply lt_dec.
+ apply Compare_dec.lt_dec.
- intros.
assert (FiniteConditionalExpectation prts sub (fun ω : Ts => vector_nth i pf1 (pos_Rvector_mult (f ω) W)) x =
FiniteConditionalExpectation prts sub (fun ω : Ts => vector_nth i pf2 (pos_Rvector_mult (f ω) W)) x).
@@ -5676,7 +5676,7 @@ Proof.
apply Rle_abs.
+ apply almost_forall; intros.
apply almost_bounded_forall; intros.
- * apply lt_dec.
+ * apply Compare_dec.lt_dec.
* eapply Rbar_le_trans; cycle 1.
apply H3.
apply slln.eq_Rbar_le.
@@ -6221,7 +6221,7 @@ Section qlearn.
intros ?.
now apply sum_n_ext.
- apply bounded_forall_almost.
- + intros; apply lt_dec.
+ + intros; apply Compare_dec.lt_dec.
+ intros ????.
apply ex_series_ext.
intros.
diff --git a/coq/QLearn/lim_add.v b/rocq/QLearn/lim_add.v
similarity index 100%
rename from coq/QLearn/lim_add.v
rename to rocq/QLearn/lim_add.v
diff --git a/coq/QLearn/qlearn.v b/rocq/QLearn/qlearn.v
similarity index 99%
rename from coq/QLearn/qlearn.v
rename to rocq/QLearn/qlearn.v
index 4d0abada..8d158b4c 100644
--- a/coq/QLearn/qlearn.v
+++ b/rocq/QLearn/qlearn.v
@@ -115,7 +115,7 @@ Section rv_expressible.
split.
+ intros.
generalize (H1 n); intros.
- destruct (le_dec x0 n).
+ destruct (Compare_dec.le_dec x0 n).
* specialize (H2 l).
rewrite Rabs_lt_between in H2.
lra.
@@ -174,7 +174,7 @@ Section rv_expressible.
specialize (H0 M).
destruct H0.
simpl.
- destruct (le_dec x n).
+ destruct (Compare_dec.le_dec x n).
* now apply H0.
* assert (n < x)%nat by lia.
generalize (increasing_seq f H n (x - n)%nat); intros.
@@ -324,7 +324,7 @@ End rv_expressible.
generalize (scal r y2) as d.
intros.
unfold minus.
- rewrite opp_plus.
+ rewrite (@opp_plus X).
rewrite plus_assoc.
rewrite plus_assoc.
f_equal.
@@ -577,7 +577,7 @@ End rv_expressible.
(forall n, 0 <= α n <= 1) ->
(forall n, 0 <= (1-gamma)* α (n+k)%nat < 1) ->
is_lim_seq α 0 ->
- is_lim_seq (sum_n α) p_infty ->
+ is_lim_seq (@sum_n R_AbelianGroup α) p_infty ->
is_lim_seq (fun n => prod_f_R0 (fun m => g_alpha gamma (α (m + k)%nat)) n) 0.
Proof.
intros.
@@ -597,7 +597,7 @@ End rv_expressible.
reflexivity.
- unfold l1_divergent.
apply is_lim_seq_ext with
- (u := (fun m => (1-gamma) * (sum_n (fun n => α (n + k)%nat) m))).
+ (u := (fun m => (1-gamma) * (@sum_n R_AbelianGroup (fun n => α (n + k)%nat) m))).
+ intros.
generalize (sum_n_mult_l (1-gamma) (fun n => α (n + k)%nat) n); intros.
unfold Hierarchy.mult in H6; simpl in H6.
@@ -613,9 +613,9 @@ End rv_expressible.
lia.
-- assert (k > 0)%nat by lia.
apply is_lim_seq_ext with
- (u := fun m => minus (sum_n α (m + k)%nat) (sum_n α (k-1)%nat)).
+ (u := fun m => minus (@sum_n R_AbelianGroup α (m + k)%nat) (@sum_n R_AbelianGroup α (k-1)%nat)).
++ intros.
- rewrite <- sum_n_m_sum_n; trivial; try lia.
+ rewrite <- (@sum_n_m_sum_n R_AbelianGroup); trivial; try lia.
replace (S (k-1)%nat) with (k) by lia.
apply sum_n_m_shift.
++ apply is_lim_seq_minus with
@@ -698,10 +698,10 @@ End rv_expressible.
unfold ex_finite_lim_seq.
exists (x - (sum_n g_α N)).
apply is_lim_seq_ext_loc with
- (u := (fun n => minus (sum_n g_α n) (sum_n g_α N))).
+ (u := (fun n => minus (@sum_n R_AbelianGroup g_α n) (@sum_n R_AbelianGroup g_α N))).
exists N.
intros.
- rewrite sum_n_m_sum_n; trivial.
+ rewrite (@sum_n_m_sum_n R_AbelianGroup); trivial.
apply is_lim_seq_minus'; trivial.
apply is_lim_seq_const.
unfold ex_finite_lim_seq in H4.
@@ -742,7 +742,7 @@ End rv_expressible.
assert (n = 0%nat ) by lia.
rewrite H2.
apply H0.
- - destruct (le_dec n N).
+ - destruct (Compare_dec.le_dec n N).
+ apply IHN; trivial.
intros; apply H; lia.
rewrite sum_Sn in H0.
@@ -2230,6 +2230,7 @@ End rv_expressible.
** rewrite (scal_minus_distr_r 1).
unfold minus.
rewrite <- plus_assoc.
+ change (xstar = plus (scal 1 xstar) (plus (opp (scal (α n) xstar)) (scal (α n) xstar))).
rewrite plus_opp_l.
rewrite plus_zero_r.
now generalize (scal_one xstar).
@@ -3436,7 +3437,9 @@ End rv_expressible.
now rewrite IHn.
+ apply is_lim_seq_const.
- intros.
- rewrite minus_eq_zero, norm_zero.
+ rewrite minus_eq_zero.
+ change (norm (@zero R_NormedModule) <= gamma * norm (minus x y)).
+ rewrite norm_zero.
apply Rmult_le_pos; try lra.
apply norm_ge_0.
- intros; apply prod_f_R0_proper; [|trivial].
diff --git a/coq/QLearn/qlearn_aux.v b/rocq/QLearn/qlearn_aux.v
similarity index 100%
rename from coq/QLearn/qlearn_aux.v
rename to rocq/QLearn/qlearn_aux.v
diff --git a/coq/QLearn/qlearn_redux.v b/rocq/QLearn/qlearn_redux.v
similarity index 99%
rename from coq/QLearn/qlearn_redux.v
rename to rocq/QLearn/qlearn_redux.v
index 0e8ac9a2..5f1882cc 100644
--- a/coq/QLearn/qlearn_redux.v
+++ b/rocq/QLearn/qlearn_redux.v
@@ -1906,7 +1906,7 @@ Lemma Dvoretzky_converge_Y (C : R) (Y : nat -> Ts -> R) (alpha : nat -> Ts -> R)
{adaptY : IsAdapted borel_sa Y F} (adapt_alpha : IsAdapted borel_sa alpha F)
(alpha_pos:forall n x, 0 <= alpha n x)
(alpha_one:forall n x, 0 <= 1 - alpha n x ) :
- almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => alpha n omega))
+ almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => alpha n omega))
Rbar.p_infty) ->
rv_eq (Y 0%nat) (const C) ->
(forall n, rv_eq (Y (S n)) (rvplus (rvmult (rvminus (const 1) (alpha n)) (Y n)) (rvscale (gamma * C) (alpha n)))) ->
@@ -2042,10 +2042,10 @@ Lemma Dvoretzky_converge_Z (Z BB: nat -> Ts -> R) (alpha : nat -> Ts -> R)
(fun x : Ts =>
ConditionalExpectation.ConditionalExpectation prts (filt_sub n) (BB n) x =
Rbar.Finite (const 0 x))) ->
- almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => alpha n omega))
+ almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => alpha n omega))
Rbar.p_infty) ->
(exists (A2 : R),
- almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) ->
+ almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) ->
(exists (sigma : R), forall n, rv_le (rvsqr (BB n)) (const (Rsqr sigma))) ->
rv_eq (Z 0%nat) (const 0) ->
(forall n, rv_eq (Z (S n)) (rvplus (rvmult (rvminus (const 1) (alpha n)) (Z n)) (rvmult (BB n) (alpha n)))) ->
@@ -2276,6 +2276,10 @@ Proof.
typeclasses eauto.
}
specialize (H4 H5 svy1).
+ change (Rbar.is_finite
+ (Lim_seq.Lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => FiniteExpectation prts (rvsqr (alpha n)))))
+).
+
rewrite (Lim_seq.Lim_seq_ext _ _ H4).
destruct alpha_sqr as [A2 alpha_sqr].
generalize (Dominated_convergence_almost
@@ -2303,7 +2307,7 @@ Proof.
unfold rvsum.
left.
generalize (Lim_seq_increasing_le
- (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n0 : nat => rvsqr (alpha n0) x))); intros.
+ (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n0 : nat => rvsqr (alpha n0) x))); intros.
cut_to H8.
--- specialize (H8 n).
generalize (Rbar.Rbar_le_lt_trans _ _ _ H8 H7); intros.
@@ -2359,7 +2363,7 @@ Proof.
simpl.
unfold rvsum.
rewrite Rabs_right.
- ** generalize (Lim_seq_increasing_le (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n0 : nat => rvsqr (alpha n0) x))); intros.
+ ** generalize (Lim_seq_increasing_le (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n0 : nat => rvsqr (alpha n0) x))); intros.
cut_to H8.
--- specialize (H8 n).
generalize (Rbar.Rbar_le_lt_trans _ _ _ H8 H7); intros.
@@ -3406,10 +3410,10 @@ Lemma list_inter_prob_bound (l : list (event dom * R)) :
(forall n, independent_sas prts (filt_sub n)
(pullback_rv_sub dom cod (X (S n)) (rvX (S n)))) ->
(forall n, FiniteExpectation prts (BB ∘ X n) = 0) ->
- almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => alpha n omega))
+ almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => alpha n omega))
Rbar.p_infty) ->
(exists (A2 : R),
- almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) ->
+ almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) ->
(exists (sigma : R), forall n, rv_le (rvsqr (BB ∘ X n)) (const (Rsqr sigma))) ->
rv_eq (Z 0%nat) (const 0) ->
(forall n, rv_eq (Z (S n)) (rvplus (rvmult (rvminus (const 1) (alpha n)) (Z n)) (rvmult (BB ∘ X (S n)) (alpha n)))) ->
@@ -3626,7 +3630,7 @@ Lemma list_inter_prob_bound (l : list (event dom * R)) :
{isfe : forall n, IsFiniteExpectation prts (BB ∘ X n)}
(alpha_pos:forall n x, 0 <= alpha n x)
(alpha_one:forall n x, 0 <= 1 - alpha n x ) :
- almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => alpha n omega))
+ almost prts (fun omega : Ts => Lim_seq.is_lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => alpha n omega))
Rbar.p_infty) ->
gamma + eps < 1 ->
gamma < 1 ->
@@ -3641,7 +3645,7 @@ Lemma list_inter_prob_bound (l : list (event dom * R)) :
(forall n, rv_eq (Z (S n)) (rvplus (rvmult (rvminus (const 1) (alpha n)) (Z n)) (rvmult (BB ∘ X (S n)) (alpha n)))) ->
(exists (A2 : R),
- almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n Hierarchy.R_AbelianGroup (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) ->
+ almost prts (fun omega => Rbar.Rbar_lt (Lim_seq.Lim_seq (@Hierarchy.sum_n (Hierarchy.AbelianGroup.AbelianMonoid Hierarchy.R_AbelianGroup) (fun n : nat => rvsqr (alpha n) omega))) (Rbar.Finite A2))) ->
(exists (sigma : R), forall n, rv_le (rvsqr (BB ∘ X n)) (const (Rsqr sigma))) ->
(forall N, (N >= tk)%nat ->
forall omega,
diff --git a/coq/QLearn/slln.v b/rocq/QLearn/slln.v
similarity index 99%
rename from coq/QLearn/slln.v
rename to rocq/QLearn/slln.v
index 9bf0ea5c..e3385f7a 100644
--- a/coq/QLearn/slln.v
+++ b/rocq/QLearn/slln.v
@@ -1,4 +1,3 @@
-Require Import Qreals.
Require Import Lra Lia Reals RealAdd RandomVariableL2 Coquelicot.Coquelicot.
Require Import Morphisms FiniteType List ListAdd Permutation infprod Almost NumberIso.
Require Import Sums SimpleExpectation PushNeg.
@@ -6,7 +5,6 @@ Require Import EquivDec.
Require Import Classical.
Require Import ClassicalChoice.
Require Import IndefiniteDescription ClassicalDescription.
-Require QArith.
Require Import BorelSigmaAlgebra.
Require Import utils.Utils.
Require Import ConditionalExpectation.
@@ -167,7 +165,7 @@ Lemma ash_6_1_2 {a x : nat -> R} {x0 : R}(ha : forall n, 0 <= a n)
(hb1 : forall n, 0 < sum_f_R0 a n)(hb2 : is_lim_seq (sum_f_R0 a) p_infty) (hx : is_lim_seq x x0):
is_lim_seq (fun n => (sum_f_R0 (fun j => a j * x j) n)/(sum_f_R0 a n)) x0.
Proof.
- pose (A := fun (n j : nat) => if (le_dec j n) then (a j)/(sum_f_R0 a n) else 0).
+ pose (A := fun (n j : nat) => if (Compare_dec.le_dec j n) then (a j)/(sum_f_R0 a n) else 0).
assert (Apos: forall n j, 0 <= A n j).
{
intros.
@@ -204,7 +202,7 @@ Lemma ash_6_1_2 {a x : nat -> R} {x0 : R}(ha : forall n, 0 <= a n)
+ now rewrite Lim_seq_const.
+ exists n; intros.
rewrite sum_n_ext with
- (b := (fun j : nat => (if le_dec j n then (((a j)*(x j)) / sum_f_R0 a n) else 0))).
+ (b := (fun j : nat => (if Compare_dec.le_dec j n then (((a j)*(x j)) / sum_f_R0 a n) else 0))).
* rewrite <- sum_n_sum_f_clipped with (N := n); try lia.
rewrite sum_n_Reals.
unfold Rdiv.
@@ -309,7 +307,7 @@ Proof.
cut_to H3; trivial.
* eapply (is_lim_seq_ext _ _ x0 _ H3).
* intros.
- destruct (lt_dec 0 n).
+ destruct (Compare_dec.lt_dec 0 n).
-- specialize (hb1 (n-1)%nat).
replace (S (n-1)) with (n) in hb1 by lia.
lra.
@@ -374,7 +372,7 @@ Lemma ash_6_1_3_strong1 {b x : nat -> R} (hb1 : forall n, 0 < b n <= b (S n)) (h
(hx : ex_series x):
is_lim_seq (fun n => (sum_n_m (fun j => b j * x j) 1 n)/(b n)) 0.
Proof.
- pose (bb := fun n => if (lt_dec 0 n) then (b n) else 0).
+ pose (bb := fun n => if (Compare_dec.lt_dec 0 n) then (b n) else 0).
generalize (@ash_6_1_3 bb x); intros.
cut_to H; trivial.
- apply is_lim_seq_ext with (v := fun n => sum_n_m (fun j => b j * x j) 1 (S n) / (b (S n))) in H.
@@ -1693,7 +1691,7 @@ Qed.
rv_unfold.
rewrite Rmax_list_app by now simpl.
unfold Rmax.
- rewrite plus_0_l.
+ rewrite Nat.add_0_l.
destruct (Rle_dec (Rmax_list (map (fun n : nat => F n a) (seq 0 (S k)))) (F (S k) a)); trivial.
}
apply IsFiniteExpectation_case.
@@ -2014,7 +2012,7 @@ Proof.
unfold Sum, rvsum. rewrite sum_Sn. unfold plus. simpl.
rewrite Rplus_comm.
unfold Rminus; rewrite Rplus_assoc.
- replace (sum_n (fun n0 : nat => X (n0+m)%nat w) j + - sum_n (fun n0 : nat => X (n0+m)%nat w) j) with 0 by lra.
+ replace (@sum_n R_AbelianGroup (fun n0 : nat => X (n0+m)%nat w) j + - @sum_n R_AbelianGroup (fun n0 : nat => X (n0+m)%nat w) j) with 0 by lra.
rewrite Rplus_0_r.
match_destr.
- match_destr.
@@ -2867,12 +2865,12 @@ Qed.
Qed.
Lemma sum_shift_diff (X : nat -> R) (m a : nat) :
- sum_n X (a + S m) - sum_n X m =
- sum_n (fun n0 : nat => X (n0 + S m)%nat) a.
+ @sum_n R_AbelianGroup X (a + S m) - @sum_n R_AbelianGroup X m =
+ @sum_n R_AbelianGroup (fun n0 : nat => X (n0 + S m)%nat) a.
Proof.
rewrite <- sum_n_m_shift.
unfold sum_n.
- rewrite (@sum_split _ _ _ _ m); try lia.
+ rewrite (@sum_split R_AbelianGroup _ _ _ m); try lia.
unfold plus; simpl.
lra.
Qed.
@@ -3997,7 +3995,7 @@ Qed.
match_destr.
}
unfold independent_event_collection in indep.
- destruct (in_dec NPeano.Nat.eq_dec 0%nat l).
+ destruct (in_dec Nat.eq_dec 0%nat l).
- pose (ll := 1%nat :: map (fun n => match n with
| 0%nat => n
| S n' => S n
@@ -4024,13 +4022,28 @@ Qed.
etransitivity; [| etransitivity]; [| apply indep' |].
+ apply ps_proper.
unfold ll.
- rewrite perm.
- simpl.
+ etransitivity.
+ { apply list_inter_equivlist_proper.
+ apply Permutation_equivlist.
+ apply Permutation_map.
+ apply perm.
+ }
+ etransitivity; cycle 1.
+ {
+ apply list_inter_equivlist_proper.
+ apply Permutation_equivlist.
+ apply Permutation_map.
+ apply Permutation_cons; [reflexivity |].
+ apply Permutation_map.
+ symmetry.
+ apply perm.
+ }
+ repeat rewrite map_cons.
repeat rewrite list_inter_cons.
rewrite event_inter_assoc.
apply event_inter_proper.
- * rewrite event_inter_comm.
- apply H6.
+ * red; simpl.
+ now rewrite pre_event_inter_comm.
* apply list_inter_proper.
rewrite map_map.
apply Forall2_map_f.
@@ -4924,7 +4937,7 @@ Qed.
- apply collection_take_preserves_disjoint.
intros ????[??][??]; simpl in *.
apply H.
- apply le_antisym.
+ apply Nat.le_antisymm.
+ assert (INR n1 < INR (S n2)) by (rewrite S_INR; lra).
apply INR_lt in H4.
lia.
@@ -5292,7 +5305,7 @@ Qed.
apply Rbar_finite_eq.
lra.
-- intros.
- rewrite sum_split with (m := n); try lia.
+ rewrite (@sum_split R_AbelianGroup) with (m := n); try lia.
rewrite sum_n_m_ext_loc with (b := fun _ => zero).
++ rewrite sum_n_m_const_zero.
rewrite plus_zero_l.
@@ -5324,7 +5337,7 @@ Qed.
apply ps_pos.
+ intros.
unfold sum_n.
- rewrite sum_split with (m := n); try lia.
+ rewrite (@sum_split R_AbelianGroup) with (m := n); try lia.
reflexivity.
Qed.
@@ -5387,7 +5400,7 @@ Qed.
-- replace (n0 + S (S a))%nat with (S (n0 + S a)) by lia.
rewrite sum_Rbar_n_finite_sum_n.
unfold sum_n.
- rewrite sum_split with (m := a); try lia.
+ rewrite (@sum_split R_AbelianGroup) with (m := a); try lia.
rewrite plus_comm.
rewrite sum_n_m_ext_loc with (b := (fun _ => @zero R_AbelianGroup)).
++ rewrite sum_n_m_const_zero.
@@ -5614,7 +5627,7 @@ Qed.
forall n m,
sum_n_m f m (m + n)%nat =
sum_n (fun k =>
- (if (le_dec m k) then 1 else 0) * (f k))
+ (if (Compare_dec.le_dec m k) then 1 else 0) * (f k))
(m + n)%nat.
Proof.
intros.
@@ -5648,7 +5661,7 @@ Qed.
forall n m,
Rbar.Finite(sum_n_m f m (m + n)%nat) =
sum_Rbar_n (fun k =>
- (if (le_dec m k) then 1 else 0) * (f k))
+ (if (Compare_dec.le_dec m k) then 1 else 0) * (f k))
(S (m + n)).
Proof.
intros.
@@ -5659,7 +5672,7 @@ Qed.
Lemma sum_inv_sq_Elim :
forall m,
Rbar_le
- (ELim_seq (sum_Rbar_n (fun j : nat => (if le_dec m j then 1 else 0) / (INR (S j))²)))
+ (ELim_seq (sum_Rbar_n (fun j : nat => (if Compare_dec.le_dec m j then 1 else 0) / (INR (S j))²)))
(2 / (INR (S m))).
Proof.
intros.
@@ -5825,7 +5838,7 @@ Qed.
ELim_seq
(sum_Rbar_n
(fun n0 : nat =>
- (if le_dec n0 x then 1 else 0) *
+ (if Compare_dec.le_dec n0 x then 1 else 0) *
(f n0))) =
(sum_n f x).
Proof.
@@ -6576,7 +6589,7 @@ Qed.
(ELim_seq
(sum_Rbar_n
(fun n0 : nat =>
- (if (le_dec n0 x) then 1 else 0) *
+ (if (Compare_dec.le_dec n0 x) then 1 else 0) *
FiniteExpectation Prts
(rvmult (rvsqr (X 0%nat))
(EventIndicator
@@ -6595,7 +6608,7 @@ Qed.
(event_inter
(event_lt dom (rvabs (X 0%nat)) (INR i + 1))
(event_ge dom (rvabs (X 0%nat)) (INR i)))))))
- (ELim_seq (sum_Rbar_n (fun j => (if (le_dec i j) then 1 else 0)/ (INR (S j))²))))).
+ (ELim_seq (sum_Rbar_n (fun j => (if (Compare_dec.le_dec i j) then 1 else 0)/ (INR (S j))²))))).
++ apply bounded_is_finite with (a := 0) (b := 2 * FiniteExpectation Prts (rvabs (X 0%nat))).
--- apply ELim_seq_nneg; intros.
apply sum_Rbar_n_pos; intros.
@@ -6816,10 +6829,10 @@ Qed.
(classic_dec
(event_inter (event_lt dom (rvabs (X 0%nat)) (INR x + 1))
(event_ge dom (rvabs (X 0%nat)) (INR x)))))))
- ((sum_Rbar_n (fun j : nat => (if le_dec x j then 1 else 0) / (INR (S j))²)) n))).
+ ((sum_Rbar_n (fun j : nat => (if Compare_dec.le_dec x j then 1 else 0) / (INR (S j))²)) n))).
** rewrite ELim_seq_scal_l; trivial.
unfold Rdiv.
- assert (is_finite (ELim_seq (sum_Rbar_n (fun j : nat => (if le_dec x j then 1 else 0) * / (INR (S j))²)))).
+ assert (is_finite (ELim_seq (sum_Rbar_n (fun j : nat => (if Compare_dec.le_dec x j then 1 else 0) * / (INR (S j))²)))).
{
apply bounded_is_finite with (a := 0) (b := 2 / INR (S x)).
- apply ELim_seq_nneg; intros.
@@ -6837,7 +6850,7 @@ Qed.
now rewrite <- H12.
** intros.
assert (forall j,
- Rbar.Finite ((if le_dec x j then 1 else 0) *
+ Rbar.Finite ((if Compare_dec.le_dec x j then 1 else 0) *
(FiniteExpectation Prts
(rvmult (rvsqr (X 0%nat))
(EventIndicator
@@ -6847,7 +6860,7 @@ Qed.
(rvmult (rvsqr (X 0%nat))
(EventIndicator
(classic_dec (event_inter (event_lt dom (rvabs (X 0%nat)) (INR x + 1)) (event_ge dom (rvabs (X 0%nat)) (INR x))))))) *
- ((if le_dec x j then 1 else 0) / (INR (S j))²))).
+ ((if Compare_dec.le_dec x j then 1 else 0) / (INR (S j))²))).
{
intros.
apply Rbar_finite_eq.
@@ -6870,7 +6883,7 @@ Qed.
(EventIndicator
(classic_dec (event_inter (event_lt dom (rvabs (X 0%nat)) (INR x + 1))
(event_ge dom (rvabs (X 0%nat)) (INR x)))))))
- ((if le_dec x x0 then 1 else 0) / (INR (S x0))²)).
+ ((if Compare_dec.le_dec x x0 then 1 else 0) / (INR (S x0))²)).
+++ rewrite sum_n_scal_l.
reflexivity.
+++ intros.
@@ -6897,7 +6910,7 @@ Qed.
rewrite ELim_seq_ext with
(v := (sum_Rbar_n
(fun n0 : nat =>
- (if le_dec n0 x then 1 else 0) *
+ (if Compare_dec.le_dec n0 x then 1 else 0) *
(FiniteExpectation Prts
(rvmult (rvsqr (X 0%nat))
(EventIndicator
diff --git a/coq/QLearn/sumtest.v b/rocq/QLearn/sumtest.v
similarity index 100%
rename from coq/QLearn/sumtest.v
rename to rocq/QLearn/sumtest.v
diff --git a/coq/QLearn/uniform_converge.v b/rocq/QLearn/uniform_converge.v
similarity index 100%
rename from coq/QLearn/uniform_converge.v
rename to rocq/QLearn/uniform_converge.v
diff --git a/coq/QLearn/vecslln.v b/rocq/QLearn/vecslln.v
similarity index 99%
rename from coq/QLearn/vecslln.v
rename to rocq/QLearn/vecslln.v
index be18ee32..45fd50fb 100644
--- a/coq/QLearn/vecslln.v
+++ b/rocq/QLearn/vecslln.v
@@ -1,4 +1,3 @@
-Require Import QArith.Qreals.
Require Import Lra Lia Reals RealAdd RandomVariableL2 Coquelicot.Coquelicot.
Require Import Morphisms FiniteType List ListAdd Permutation infprod Almost NumberIso.
Require Import Sums SimpleExpectation PushNeg.
@@ -348,7 +347,7 @@ Section conv_as.
intros.
apply almost_bounded_forall; trivial.
- intros.
- apply lt_dec.
+ apply Compare_dec.lt_dec.
- intros.
revert H0.
apply is_lim_seq_ext.
@@ -585,7 +584,7 @@ Section vec_cauchy.
Lemma nth_exist_join {T} {size:nat} {P} (f:forall (i:nat) (pf:(i < size)%nat), exists (t:T), P i pf t) :
exists (ts:list T), length ts = size /\ forall i pf d, P i pf (nth i ts d).
Proof.
- destruct (nth_exist_join_aux f _ (le_refl _)) as [??]; firstorder.
+ destruct (nth_exist_join_aux f _ (Nat.le_refl _)) as [??]; firstorder.
Qed.
Lemma Hnorm_vector0 (x : vector R 0) : hilbert.Hnorm x = 0.
@@ -806,7 +805,7 @@ Section ash.
(hb1 : forall n, 0 < sum_f_R0 a n)(hb2 : is_lim_seq (sum_f_R0 a) p_infty) (hx : is_lim_seq x x0):
is_lim_seq (fun n => (sum_f_R0 (fun j => a j * x j) n)/(sum_f_R0 a n)) x0.
Proof.
- pose (A := fun (n j : nat) => if (le_dec j n) then (a j)/(sum_f_R0 a n) else 0).
+ pose (A := fun (n j : nat) => if (Compare_dec.le_dec j n) then (a j)/(sum_f_R0 a n) else 0).
assert (Apos: forall n j, 0 <= A n j).
{
intros.
@@ -843,7 +842,7 @@ Section ash.
+ now rewrite Lim_seq_const.
+ exists n; intros.
rewrite sum_n_ext with
- (b := (fun j : nat => (if le_dec j n then (((a j)*(x j)) / sum_f_R0 a n) else 0))).
+ (b := (fun j : nat => (if Compare_dec.le_dec j n then (((a j)*(x j)) / sum_f_R0 a n) else 0))).
* rewrite <- sum_n_sum_f_clipped with (N := n); try lia.
rewrite sum_n_Reals.
unfold Rdiv.
@@ -948,7 +947,7 @@ Section ash.
cut_to H3; trivial.
* eapply (is_lim_seq_ext _ _ x0 _ H3).
* intros.
- destruct (lt_dec 0 n).
+ destruct (Compare_dec.lt_dec 0 n).
-- specialize (hb1 (n-1)%nat).
replace (S (n-1)) with (n) in hb1 by lia.
lra.
@@ -1013,7 +1012,7 @@ Section ash.
(hx : ex_series x):
is_lim_seq (fun n => (sum_n_m (fun j => b j * x j) 1 n)/(b n)) 0.
Proof.
- pose (bb := fun n => if (lt_dec 0 n) then (b n) else 0).
+ pose (bb := fun n => if (Compare_dec.lt_dec 0 n) then (b n) else 0).
generalize (@ash_6_1_3 bb x); intros.
cut_to H; trivial.
- apply is_lim_seq_ext with (v := fun n => sum_n_m (fun j => b j * x j) 1 (S n) / (b (S n))) in H.
@@ -2611,7 +2610,7 @@ End ash.
rv_unfold.
rewrite Rmax_list_app by now simpl.
unfold Rmax.
- rewrite plus_0_l.
+ rewrite Nat.add_0_l.
destruct (Rle_dec (Rmax_list (map (fun n : nat => F n a) (seq 0 (S k)))) (F (S k) a)); trivial.
}
apply IsFiniteExpectation_case.
@@ -2668,7 +2667,7 @@ End ash.
rv_unfold.
rewrite Rmax_list_app by now simpl.
unfold Rmax.
- rewrite plus_0_l.
+ rewrite Nat.add_0_l.
destruct (Rle_dec (Rmax_list (map (fun n : nat => F n a) (seq 0 (S k)))) (F (S k) a)); trivial.
}
apply IsFiniteExpectation_case; trivial.
@@ -2695,7 +2694,7 @@ End ash.
apply in_map_iff.
exists j'; split; trivial.
apply in_seq; lia.
- - destruct (le_dec k' k).
+ - destruct (Compare_dec.le_dec k' k).
+ specialize (IHk l).
eapply Rle_trans.
* apply IHk.
@@ -3257,7 +3256,7 @@ End ash.
+ apply cond_pos.
+ exists q.
split.
- * apply Rlt_Qlt.
+ * apply Qreals.Rlt_Qlt.
unfold QArith_base.inject_Z.
unfold Q2R.
simpl.
@@ -3270,7 +3269,7 @@ End ash.
- intros [eps [epos HH]].
assert (qepspos: 0 < Q2R eps).
{
- apply Qlt_Rlt in epos.
+ apply Qreals.Qlt_Rlt in epos.
now rewrite RMicromega.Q2R_0 in epos.
}
exists (mkposreal (Q2R eps) qepspos).
@@ -3285,23 +3284,23 @@ End ash.
destruct (Rlt_dec 0 (Q2R i)).
- assert (QArith_base.Qlt {| QArith_base.Qnum := 0; QArith_base.Qden := 1 |} i).
{
- apply Rlt_Qlt.
+ apply Qreals.Rlt_Qlt.
now rewrite RMicromega.Q2R_0.
}
eapply (sa_proper _ (fun omega => (forall N : nat,
exists n m : nat,
(n >= N)%nat /\ (m >= N)%nat /\ hilbert.Hnorm (minus (X n omega) (X m omega)) >= Q2R i))).
- + firstorder.
+ + red; simpl; intuition.
+ apply sa_pre_countable_inter; intros N.
now apply (vec_sa_sigma_not_cauchy X (mkposreal _ r)).
- eapply sa_proper; try apply sa_none.
assert (~ QArith_base.Qlt {| QArith_base.Qnum := 0; QArith_base.Qden := 1 |} i).
{
intros qlt.
- apply Qlt_Rlt in qlt.
+ apply Qreals.Qlt_Rlt in qlt.
now rewrite RMicromega.Q2R_0 in qlt.
}
- firstorder.
+ red; intuition.
Qed.
Lemma vec_sa_sigma_cauchy_descending {size:nat} (X : nat -> Ts -> vector R size )(eps : posreal)
@@ -3660,7 +3659,7 @@ End ash.
Qed.
Lemma vec_sum_n_m_shift {size:nat} (α : nat -> vector R size) (k n0 : nat) :
- sum_n_m α k (n0 + k)%nat = sum_n (fun n1 : nat => α (n1 + k)%nat) n0.
+ @sum_n_m Rvector_AbelianGroup α k (n0 + k)%nat = @sum_n Rvector_AbelianGroup (fun n1 : nat => α (n1 + k)%nat) n0.
Proof.
unfold sum_n.
induction n0.
@@ -3675,9 +3674,10 @@ End ash.
Qed.
Lemma vec_sum_shift_diff {size:nat} (X : nat -> vector R size) (m a : nat) :
- minus (sum_n X (a + S m)) (sum_n X m) =
- sum_n (fun n0 : nat => X (n0 + S m)%nat) a.
+ minus (@sum_n Rvector_AbelianGroup X (a + S m)) (@sum_n Rvector_AbelianGroup X m) =
+ @sum_n Rvector_AbelianGroup (fun n0 : nat => X (n0 + S m)%nat) a.
Proof.
+ simpl.
rewrite <- vec_sum_n_m_shift.
unfold sum_n.
rewrite (@sum_split _ _ _ _ m); try lia.
@@ -3806,7 +3806,7 @@ End ash.
rewrite minus_eq_zero in r.
rewrite (@hilbert.norm_zero) in r.
generalize (is_pos_div_2 eps); intros; lra.
- + assert (n > N)%nat by (destruct H; try lia;firstorder).
+ + assert (n > N)%nat by (destruct H; try lia;intuition).
exists (n - (S N))%nat.
unfold vecrvminus.
now replace (n - S N + S N)%nat with (n) by lia.
diff --git a/coq/lib_utils/LibUtils.v b/rocq/lib_utils/LibUtils.v
similarity index 100%
rename from coq/lib_utils/LibUtils.v
rename to rocq/lib_utils/LibUtils.v
diff --git a/coq/lib_utils/LibUtilsAssoc.v b/rocq/lib_utils/LibUtilsAssoc.v
similarity index 100%
rename from coq/lib_utils/LibUtilsAssoc.v
rename to rocq/lib_utils/LibUtilsAssoc.v
diff --git a/coq/lib_utils/LibUtilsBag.v b/rocq/lib_utils/LibUtilsBag.v
similarity index 98%
rename from coq/lib_utils/LibUtilsBag.v
rename to rocq/lib_utils/LibUtilsBag.v
index 5ee53180..507f8c42 100644
--- a/coq/lib_utils/LibUtilsBag.v
+++ b/rocq/lib_utils/LibUtilsBag.v
@@ -17,8 +17,6 @@
(** This module provides support for bags (or multisets). *)
Require Import Arith.
-Require Import Min.
-Require Import Max.
Require Import ZArith.
Require Import Lia.
Require Import Permutation.
@@ -843,8 +841,8 @@ Section Bag.
generalize (mult x0 a); intro; auto with arith.
rewrite H2.
rewrite IHl; rewrite H2.
- generalize (mult l a); intro; rewrite <- succ_min_distr; auto with arith.
- rewrite H0; rewrite min_0_l; simpl.
+ generalize (mult l a); intro; rewrite <- Nat.succ_min_distr; auto with arith.
+ rewrite H0; rewrite Nat.min_0_l; simpl.
elim (equiv_dec a a); intro; try congruence.
rewrite IHl; rewrite H0; auto with arith.
rewrite IHl.
@@ -878,10 +876,10 @@ Section Bag.
+ rewrite H in *; clear H.
case (equiv_dec x x); try congruence.
intros.
- apply gt_Sn_O.
+ auto with arith.
+ case (equiv_dec x a).
* intros.
- apply gt_Sn_O.
+ auto with arith.
* intros.
apply IHl.
assumption.
@@ -896,7 +894,7 @@ Section Bag.
apply mult_pos_equiv_in.
apply mult_pos_equiv_in in Hx.
rewrite bdistinct_mult.
- apply min_case; auto.
+ apply Nat.min_case; auto.
Qed.
Lemma bdistinct_nil {l:list A} :
@@ -1033,7 +1031,7 @@ Section Bag.
Proof.
intros; revert s.
induction t; simpl.
- intro. rewrite <- minus_n_O; reflexivity.
+ intro. rewrite Nat.sub_0_r; reflexivity.
intros; rewrite IHt.
destruct (equiv_dec x a).
rewrite e in *; clear e.
@@ -1103,7 +1101,7 @@ Section Bag.
forall (n n0:nat), n0 - (n0 - n) = min n0 n.
Proof.
induction n.
- intros; rewrite min_0_r; rewrite <- minus_n_O; rewrite minus_diag; reflexivity.
+ intros; rewrite Nat.min_0_r; rewrite Nat.sub_0_r; rewrite Nat.sub_diag; reflexivity.
intros.
generalize eq_nat_dec; intro.
elim (H n n0); intro; clear H.
@@ -1120,7 +1118,7 @@ Section Bag.
forall (n n0:nat), n0 + (n - n0) = max n0 n.
Proof.
induction n; intros.
- rewrite max_0_r; auto with arith.
+ rewrite Nat.max_0_r; auto with arith.
generalize eq_nat_dec; intro.
elim (H n n0); intro; clear H.
rewrite a in *; clear a.
@@ -1171,7 +1169,7 @@ Section Bag.
intro.
rewrite (bmin_mult s t a).
rewrite (bmin_mult t s a).
- rewrite min_comm.
+ rewrite Nat.min_comm.
reflexivity.
Qed.
@@ -1183,7 +1181,7 @@ Section Bag.
intro.
rewrite (bmax_mult s t a).
rewrite (bmax_mult t s a).
- rewrite max_comm.
+ rewrite Nat.max_comm.
reflexivity.
Qed.
@@ -1255,7 +1253,7 @@ Section Bag.
revert H IHl1 H0; generalize (mult l1 a); generalize (mult l2 a); intros.
assert (n = 0).
+ apply min_zero with (n2 := 0); assumption.
- + clear H; rewrite H1 in *; rewrite plus_0_r in *; assumption.
+ + clear H; rewrite H1 in *; rewrite Nat.add_0_r in *; assumption.
- apply IHl1; assumption.
Qed.
@@ -1281,7 +1279,7 @@ Section Bag.
rewrite H; assumption.
- assert (n = 0).
apply min_zero with (n2 := 0); assumption.
- rewrite H; rewrite plus_0_r; assumption.
+ rewrite H; rewrite Nat.add_0_r; assumption.
- assert (1 <= n).
apply min_one_yields_one; assumption.
assert (1 <= n0).
@@ -1317,17 +1315,17 @@ Section Bag.
(mult l1 x = 0) \/ (mult l2 x = 0) -> mult (l1 min-b l2) x = 0.
Proof.
induction l2; simpl; intros; rewrite bmin_mult; simpl.
- rewrite min_0_r; reflexivity.
+ rewrite Nat.min_0_r; reflexivity.
revert H; elim (equiv_dec x a); intros.
rewrite a0 in *; clear a0.
elim H; intro; clear H.
rewrite H0 in *.
generalize (mult l2 a); auto.
rewrite H0 in *.
- generalize (mult l1 a); intro; rewrite min_0_r; reflexivity.
+ generalize (mult l1 a); intro; rewrite Nat.min_0_r; reflexivity.
elim H; intro; clear H; rewrite H0.
generalize (mult l1 a); auto.
- generalize (mult l1 a); intro; rewrite min_0_r; reflexivity.
+ generalize (mult l1 a); intro; rewrite Nat.min_0_r; reflexivity.
Qed.
Lemma bdistinct_over_bmin:
@@ -1350,7 +1348,7 @@ Section Bag.
assert (n = 0); try (apply (min_zero n 0); assumption).
rewrite H0; auto with arith.
assert (n0 = 0); try (apply (min_zero n0 0); assumption).
- rewrite H0; rewrite min_0_r; auto with arith.
+ rewrite H0; rewrite Nat.min_0_r; auto with arith.
simpl; generalize (compare_either n n0); intro.
elim H0; intro; clear H0.
assert (min n n0 = n). try (rewrite min_l; [reflexivity|assumption]).
diff --git a/coq/lib_utils/LibUtilsBindings.v b/rocq/lib_utils/LibUtilsBindings.v
similarity index 99%
rename from coq/lib_utils/LibUtilsBindings.v
rename to rocq/lib_utils/LibUtilsBindings.v
index 22b7139c..28b739d4 100644
--- a/coq/lib_utils/LibUtilsBindings.v
+++ b/rocq/lib_utils/LibUtilsBindings.v
@@ -43,20 +43,21 @@ Require Import LibUtilsStringAdd.
Section Bindings.
Class ODT {K:Type}
- := mkODT { ODT_eqdec:>EqDec K eq;
+ := mkODT { ODT_eqdec::EqDec K eq;
ODT_lt:K -> K -> Prop;
- ODT_lt_strorder:>StrictOrder ODT_lt;
+ ODT_lt_strorder::StrictOrder ODT_lt;
ODT_lt_dec: forall (a b:K), {ODT_lt a b} + {~ODT_lt a b};
ODT_compare:K -> K -> comparison;
ODT_compare_spec: forall x y : K,
CompareSpec (eq x y) (ODT_lt x y) (ODT_lt y x) (ODT_compare x y) }.
+
Generalizable Variables K.
Context `{odt:@ODT K}.
Lemma ODT_lt_irr (k:K) :
~(ODT_lt k k).
- Proof.
+ Proof.
apply irreflexivity.
Qed.
diff --git a/coq/lib_utils/LibUtilsBindingsNat.v b/rocq/lib_utils/LibUtilsBindingsNat.v
similarity index 92%
rename from coq/lib_utils/LibUtilsBindingsNat.v
rename to rocq/lib_utils/LibUtilsBindingsNat.v
index bd820896..0de72e7d 100644
--- a/coq/lib_utils/LibUtilsBindingsNat.v
+++ b/rocq/lib_utils/LibUtilsBindingsNat.v
@@ -18,13 +18,12 @@
natural number. *)
Require Import Arith.
-Require Import NPeano.
Require Import LibUtilsBindings.
Section BindingsNat.
Global Program Instance ODT_nat : (@ODT nat)
- := mkODT _ _ lt Nat.lt_strorder lt_dec Nat.compare _.
+ := mkODT _ _ lt Nat.lt_strorder Compare_dec.lt_dec Nat.compare _.
Next Obligation.
simpl.
apply Nat.compare_spec.
diff --git a/coq/lib_utils/LibUtilsClosure.v b/rocq/lib_utils/LibUtilsClosure.v
similarity index 100%
rename from coq/lib_utils/LibUtilsClosure.v
rename to rocq/lib_utils/LibUtilsClosure.v
diff --git a/coq/lib_utils/LibUtilsCompat.v b/rocq/lib_utils/LibUtilsCompat.v
similarity index 100%
rename from coq/lib_utils/LibUtilsCompat.v
rename to rocq/lib_utils/LibUtilsCompat.v
diff --git a/coq/lib_utils/LibUtilsCoqLibAdd.v b/rocq/lib_utils/LibUtilsCoqLibAdd.v
similarity index 98%
rename from coq/lib_utils/LibUtilsCoqLibAdd.v
rename to rocq/lib_utils/LibUtilsCoqLibAdd.v
index ed65514a..d887466e 100644
--- a/coq/lib_utils/LibUtilsCoqLibAdd.v
+++ b/rocq/lib_utils/LibUtilsCoqLibAdd.v
@@ -30,7 +30,6 @@ Require Import EquivDec.
Require Import Equivalence.
Require Import Peano_dec.
Require Import ZArith.
-Require Import Zdigits.
Require Import Znat.
Require Import Recdef.
Require Import Compare_dec.
@@ -186,7 +185,7 @@ Section CoqLibAdd.
Lemma nin_app_or (x:A) a b :
(~ In x (a ++ b)) <-> (~ In x a /\ ~ In x b).
Proof.
- intuition. apply in_app_or in H0; intuition.
+ intuition (auto with *). apply in_app_or in H0; intuition.
Qed.
Lemma in_in_app_false {l l1 l2} :
@@ -463,7 +462,7 @@ Section CoqLibAdd.
(forallb f l1 = true /\ forallb f l2 = true).
Proof.
repeat rewrite forallb_forall.
- intuition; rewrite in_app_iff in *; intuition.
+ intuition (auto with *); rewrite in_app_iff in *; intuition (auto with *).
Qed.
Lemma forallb_map {A B} f (mf:A->B) m : forallb f (map mf m) = forallb ((fun x => f (mf x))) m.
@@ -575,7 +574,7 @@ Section CoqLibAdd.
reflexivity.
assert (exists (n3:nat), min (S n1) (S n2) = (S n3)).
exists (min n1 n2).
- rewrite Min.succ_min_distr; reflexivity.
+ rewrite Nat.succ_min_distr; reflexivity.
elim H0; intros.
congruence.
Qed.
@@ -622,8 +621,8 @@ Section CoqLibAdd.
revert x0; induction l; simpl; intros; try lia.
rewrite (IHl (n0 * f a + x0)); simpl.
rewrite (fold_left_arith_dist1 (f a + 0)).
- rewrite mult_plus_distr_l.
- rewrite mult_plus_distr_l.
+ rewrite Nat.mul_add_distr_l.
+ rewrite Nat.mul_add_distr_l.
lia.
Qed.
@@ -633,7 +632,7 @@ Section CoqLibAdd.
Proof.
generalize 0.
revert x0; induction l; simpl; intros; try lia.
- assert (f a + n = n + f a) by apply plus_comm.
+ assert (f a + n = n + f a) by apply Nat.add_comm.
rewrite H; clear H.
rewrite (IHl x0 (n + f a)); reflexivity.
Qed.
@@ -803,7 +802,7 @@ Section CoqLibAdd.
Context {A:Type}.
Function iter_cost (optim: A -> A) (cost: A -> nat) (p: A) { measure cost p } :=
let p' := optim p in
- if lt_dec (cost p') (cost p)
+ if Compare_dec.lt_dec (cost p') (cost p)
then iter_cost optim cost p'
else p.
Proof.
@@ -1015,3 +1014,4 @@ Ltac string_eqdec_to_equiv :=
Ltac string_dec_to_equiv :=
replace string_dec with (equiv_dec (EqDec:=string_dec)) in * by trivial.
+
diff --git a/coq/lib_utils/LibUtilsDigits.v b/rocq/lib_utils/LibUtilsDigits.v
similarity index 92%
rename from coq/lib_utils/LibUtilsDigits.v
rename to rocq/lib_utils/LibUtilsDigits.v
index 49db3ea3..b52cd973 100644
--- a/coq/lib_utils/LibUtilsDigits.v
+++ b/rocq/lib_utils/LibUtilsDigits.v
@@ -61,13 +61,13 @@ Section prelude.
generalize (refl_equal n).
pattern n at 2 4 6 10, q; case q; [intro | intros m l e].
rewrite <- eq_rect_eq_nat; trivial.
- contradiction (le_Sn_n m); rewrite <- e; assumption.
+ contradiction (Nat.nle_succ_diag_l m); rewrite <- e; assumption.
replace (le_S n m p) with
(eq_rect _ (fun n0 => n <= n0) (le_S n m p) _ (refl_equal (S m))).
2:reflexivity.
generalize (refl_equal (S m)).
pattern (S m) at 1 3 4 6, q; case q; [intro Heq | intros m0 l HeqS].
- contradiction (le_Sn_n m); rewrite Heq; assumption.
+ contradiction (Nat.nle_succ_diag_l m); rewrite Heq; assumption.
injection HeqS; intro Heq; generalize l HeqS.
rewrite <- Heq; intros; rewrite <- eq_rect_eq_nat.
rewrite (IHp l0); reflexivity.
@@ -157,7 +157,7 @@ Section Digits.
split.
- unfold digits_to_nat in e1.
rewrite e1.
- rewrite mult_comm.
+ rewrite Nat.mul_comm.
rewrite <- Nat.div_mod; trivial.
lia.
- intros ? HH. destruct (rev x); simpl in * .
@@ -167,7 +167,7 @@ Section Digits.
simpl in *.
rewrite <- Nat.div_exact by lia.
rewrite <- e1.
- rewrite mult_comm.
+ rewrite Nat.mul_comm.
simpl.
lia.
+ auto.
@@ -192,8 +192,8 @@ Section Digits.
induction l; simpl; trivial.
intros.
apply IHl.
- apply plus_le_compat_r.
- apply mult_le_compat_r.
+ apply Nat.add_le_mono_r.
+ apply Nat.mul_le_mono_r.
trivial.
Qed.
@@ -209,9 +209,9 @@ Section Digits.
transitivity (acc*base).
* transitivity (acc * 1).
{ lia. }
- apply mult_le_compat_l.
+ apply Nat.mul_le_mono_l.
lia.
- * apply le_plus_l.
+ * apply Nat.le_add_r.
Qed.
Lemma digits_to_nat_aux_bound l c:
@@ -221,34 +221,28 @@ Section Digits.
induction l; simpl.
- split.
+ lia.
- + destruct (mult_O_le c (base*1)).
- * lia.
- * rewrite mult_comm.
- lia.
+ + lia.
- intros.
destruct (IHl (c * base + proj1_sig a)) as [le1 le2].
clear IHl.
split.
+ rewrite <- le1.
- rewrite mult_plus_distr_r.
- rewrite mult_assoc.
- apply le_plus_l.
- + eapply lt_le_trans; [apply le2 | ].
- repeat rewrite mult_plus_distr_r.
- repeat rewrite mult_assoc.
+ lia.
+ + eapply Nat.lt_le_trans; [apply le2 | ].
+ repeat rewrite Nat.mul_add_distr_r.
+ repeat rewrite Nat.mul_assoc.
repeat rewrite Nat.mul_1_l.
- rewrite plus_assoc_reverse.
- apply plus_le_compat_l.
+ rewrite <- Nat.add_assoc.
+ apply Nat.add_le_mono_l.
replace
(proj1_sig a * base ^ Datatypes.length l + base ^ Datatypes.length l)
with
((proj1_sig a +1) * base ^ Datatypes.length l).
- * apply mult_le_compat_r.
+ * apply Nat.mul_le_mono_r.
destruct a; simpl.
- rewrite plus_comm; simpl.
- apply lt_le_S.
- trivial.
- * rewrite mult_plus_distr_r, Nat.mul_1_l; trivial.
+ rewrite Nat.add_comm; simpl.
+ auto with arith.
+ * rewrite Nat.mul_add_distr_r, Nat.mul_1_l; trivial.
Qed.
Lemma digits_to_nat_aux_acc_inj_helper1 a b c n1 n2 :
@@ -260,71 +254,64 @@ Section Digits.
Proof.
intros ? ? ? lt1 ltn.
assert (le12:c * base * base ^ n2 + 0 <= c * base * base ^ n2 + a * base ^ n2).
- { apply plus_le_compat_l.
+ { apply Nat.add_le_mono_l.
apply Peano.le_0_n.
}
- rewrite plus_0_r in le12.
- rewrite mult_plus_distr_r in lt1.
- eapply le_lt_trans in lt1; try eapply le12.
+ rewrite Nat.add_0_r in le12.
+ rewrite Nat.mul_add_distr_r in lt1.
+ eapply Nat.le_lt_trans in lt1; try eapply le12.
assert (le13:(c * base + b + 1) * (base ^ n1)
<=
(c * base + base) * (base ^ n1 )).
{
- apply mult_le_compat_r.
- rewrite plus_assoc_reverse.
- apply plus_le_compat_l.
+ apply Nat.mul_le_mono_r.
+ rewrite <- Nat.add_assoc.
+ apply Nat.add_le_mono_l.
lia.
}
- eapply lt_le_trans in le13; try eapply lt1.
- rewrite (le_plus_minus n1 n2) in le13 by lia.
+ eapply Nat.lt_le_trans in le13; try eapply lt1.
+ rewrite <- (Nat.sub_add n1 n2) in le13 by lia.
rewrite Nat.pow_add_r in le13.
- rewrite mult_assoc in le13.
+ rewrite Nat.mul_assoc in le13.
assert (le14:c*base+base <= c*base*base).
{
replace (c*base+base) with ((c+1)*base).
- - apply mult_le_compat_r.
- rewrite mult_comm.
+ - apply Nat.mul_le_mono.
+ rewrite Nat.mul_comm.
destruct base.
+ lia.
+ simpl.
- apply plus_le_compat_l.
+ apply Nat.add_le_mono_l.
destruct n. lia.
destruct c. lia.
- apply lt_le_S.
- replace 0 with (S n *0) by auto.
- apply mult_lt_compat_l; lia.
- - rewrite mult_plus_distr_r.
- rewrite mult_1_l.
+ apply Nat.lt_succ_r.
+ auto with arith.
+ + reflexivity.
+ - rewrite Nat.mul_add_distr_r.
+ rewrite Nat.mul_1_l.
trivial.
}
assert (le15:(c * base + base) * base ^ n1 <= (c * base * base) * base ^ n1).
{
- apply mult_le_compat_r.
+ apply Nat.mul_le_mono_r.
auto.
}
- eapply lt_le_trans in le15; try eapply le13.
+ eapply Nat.lt_le_trans in le15; try eapply le13.
assert (le16:c * base * base ^ n1 * base <= c * base * base ^ n1 * base ^ (n2 - n1)).
{
- apply mult_le_compat_l.
+ apply Nat.mul_le_mono_l.
generalize (Nat.sub_gt _ _ ltn).
destruct (n2-n1).
- congruence.
- simpl; intros _ .
replace base with (base*base^0) at 1.
- + apply mult_le_compat_l.
+ + apply Nat.mul_le_mono_l.
apply Nat.pow_le_mono_r; lia.
+ simpl.
- rewrite mult_1_r.
+ rewrite Nat.mul_1_r.
trivial.
}
- eapply le_lt_trans in le16; try eapply le15.
- replace (c * base * base ^ n1 * base) with
- (c * base * base * base ^ n1) in le16.
- - intuition.
- - repeat rewrite mult_assoc_reverse.
- f_equal. f_equal.
- rewrite mult_comm.
- trivial.
+ lia.
Qed.
Lemma digits_to_nat_aux_acc_inj_helper12 a b c n1 n2 :
@@ -339,15 +326,14 @@ Section Digits.
; [ | eapply (digits_to_nat_aux_acc_inj_helper1 a b c n1 n2); eauto].
red in e; subst.
simpl in *.
- rewrite (le_plus_minus n1 n2) in lt1 by lia.
+ rewrite <- (Nat.sub_add n1 n2) in lt1 by lia.
rewrite Nat.pow_add_r in lt1.
- rewrite (mult_comm (base ^ n1)) in lt1.
- rewrite mult_assoc in lt1.
+ rewrite Nat.mul_assoc in lt1.
assert (le2:base*base^n1 <= a*base^(n2 - n1) * base ^ n1).
{
- apply mult_le_compat_r.
+ apply Nat.mul_le_mono_r.
replace base with (1*base) at 1 by lia.
- apply mult_le_compat.
+ apply Nat.mul_le_mono.
- replace 1 with (1*1) by lia.
simpl. lia.
- simpl.
@@ -355,13 +341,13 @@ Section Digits.
+ apply Nat.pow_le_mono_r; lia.
+ apply Nat.pow_1_r.
}
- eapply le_lt_trans in lt1; try eapply le2; clear le2.
+ eapply Nat.le_lt_trans in lt1; try eapply le2; clear le2.
assert (le3:(b + 1) * base ^ n1 <= base * base^n1).
{
- apply mult_le_compat_r.
+ apply Nat.mul_le_mono_r.
lia.
}
- eapply le_lt_trans in lt1; try eapply le3; clear le3.
+ eapply Nat.le_lt_trans in lt1; try eapply le3; clear le3.
lia.
Qed.
@@ -370,11 +356,10 @@ Section Digits.
~ b < a.
Proof.
intros lt1 l2.
- apply lt_not_le in lt1.
+ apply Nat.le_ngt in lt1.
apply lt1.
- apply mult_le_compat_r.
- rewrite plus_assoc_reverse.
- apply plus_le_compat_l.
+ apply Nat.lt_succ_r.
+ apply Nat.mul_le_mono_r.
lia.
Qed.
@@ -386,21 +371,21 @@ Section Digits.
~ n2 < n1.
Proof.
intros ? ? ? lt1 l2.
- apply lt_not_le in lt1.
+ apply Nat.le_ngt in lt1.
apply lt1.
- rewrite (le_plus_minus n2 n1) by lia.
+ apply Nat.lt_succ_r.
+ rewrite <- (Nat.sub_add n2 n1) by lia.
rewrite Nat.pow_add_r.
- rewrite (mult_comm a).
- rewrite (mult_comm (b+1)).
- rewrite <- mult_assoc.
- apply mult_le_compat_l.
+ rewrite Nat.mul_assoc.
+ apply Nat.mul_le_mono_r.
transitivity base; try lia.
transitivity (base^1*a).
- rewrite Nat.pow_1_r.
transitivity (base * 1); try lia.
- apply mult_le_compat_l.
+ apply Nat.mul_le_mono_l.
lia.
- - apply mult_le_compat_r.
+ - rewrite Nat.mul_comm.
+ apply Nat.mul_le_mono_l.
apply Nat.pow_le_mono_r; lia.
Qed.
@@ -414,8 +399,8 @@ Section Digits.
destruct (digits_to_nat_aux_bound l1 (c*base+a)) as [lb1 ub1].
destruct (digits_to_nat_aux_bound l2 (c*base+b)) as [lb2 ub2].
rewrite eqq1 in lb1,ub1.
- eapply le_lt_trans in lb1; [ | eapply ub2].
- eapply le_lt_trans in lb2; [ | eapply ub1].
+ eapply Nat.le_lt_trans in lb1; [ | eapply ub2].
+ eapply Nat.le_lt_trans in lb2; [ | eapply ub1].
clear eqq1 ub1 ub2.
revert lb1 lb2.
generalize (Datatypes.length l1).
@@ -448,8 +433,8 @@ Section Digits.
destruct (digits_to_nat_aux_bound l1 (c*base+a)) as [lb1 ub1].
destruct (digits_to_nat_aux_bound l2 (c*base+b)) as [lb2 ub2].
rewrite eqq1 in lb1,ub1.
- eapply le_lt_trans in lb1; [ | eapply ub2].
- eapply le_lt_trans in lb2; [ | eapply ub1].
+ eapply Nat.le_lt_trans in lb1; [ | eapply ub2].
+ eapply Nat.le_lt_trans in lb2; [ | eapply ub1].
clear eqq1 ub1 ub2.
revert lb1 lb2.
generalize (Datatypes.length l1).
@@ -485,15 +470,13 @@ Section Digits.
assert (le1:n * base <= n*1) by lia.
assert (le2:n * base <= n*1) by lia.
destruct n; [congruence|].
- apply mult_S_le_reg_l in le2.
- lia.
+ apply Nat.mul_le_mono_pos_l in le2; lia.
- generalize (digits_to_nat_aux_le l1 (n * base + proj1_sig a)); intros eqq.
rewrite H0 in eqq.
assert (le1:n * base <= n*1) by lia.
assert (le2:n * base <= n*1) by lia.
destruct n; [congruence|].
- apply mult_S_le_reg_l in le2.
- lia.
+ apply Nat.mul_le_mono_pos_l in le2; lia.
- assert (lt0:0
- match lt_dec n base with
+ match Compare_dec.lt_dec n base with
| left pf => Some (exist _ n pf)
| right _ => None
end
@@ -868,7 +848,7 @@ Section Digits.
Proof.
unfold char_to_digit.
simpl.
- destruct (lt_dec 0 base).
+ destruct (Compare_dec.lt_dec 0 base).
- f_equal.
apply digit_ext.
simpl; trivial.
@@ -934,7 +914,7 @@ Section Digits.
destruct (ascii_dec a "0"%char).
- subst.
unfold char_to_digit in eqq; simpl in eqq.
- destruct (lt_dec 0 base); [ | lia].
+ destruct (Compare_dec.lt_dec 0 base); [ | lia].
inversion eqq; clear eqq; subst.
case_eq (string_to_digits s)
; [intros ? eqq2 | intros eqq2]
@@ -1288,21 +1268,21 @@ End Digits.
Section Bases.
Definition lt_decider (a b:nat) :
- match lt_dec a b with
+ match Compare_dec.lt_dec a b with
| left pf => lt a b
| right _ => True
end.
Proof.
- destruct (lt_dec); trivial.
+ destruct (Compare_dec.lt_dec); trivial.
Defined.
Definition le_decider (a b:nat) :
- match le_dec a b with
+ match Compare_dec.le_dec a b with
| left pf => le a b
| right _ => True
end.
Proof.
- destruct (le_dec); trivial.
+ destruct (Compare_dec.le_dec); trivial.
Defined.
(** ** Base 2 *)
diff --git a/coq/lib_utils/LibUtilsFresh.v b/rocq/lib_utils/LibUtilsFresh.v
similarity index 99%
rename from coq/lib_utils/LibUtilsFresh.v
rename to rocq/lib_utils/LibUtilsFresh.v
index 754186d2..428a8dad 100644
--- a/coq/lib_utils/LibUtilsFresh.v
+++ b/rocq/lib_utils/LibUtilsFresh.v
@@ -20,7 +20,7 @@ as strings). *)
Require Import String.
Require Import List.
Require Import Permutation.
-Require Import Arith Min.
+Require Import Arith.
Require Import EquivDec.
Require Import Morphisms.
Require Import Lia.
@@ -121,7 +121,7 @@ Section Fresh.
Defined.
Definition find_fresh_inj_f {A:Type} {dec:EqDec A eq} f (inj:forall x y, f x = f y -> x = y) (dom:list A) : A
- := proj1_sig (find_bounded_S_nin_finds f dom (S (length dom)) (gt_Sn_n _) inj).
+ := proj1_sig (find_bounded_S_nin_finds f dom (S (length dom)) (Nat.lt_succ_diag_r _) inj).
Lemma find_fresh_inj_f_fresh {A:Type} {dec:EqDec A eq} f (inj:forall x y, f x = f y -> x = y) (dom:list A) :
~ In (find_fresh_inj_f f inj dom) dom.
diff --git a/coq/lib_utils/LibUtilsGroupByDomain.v b/rocq/lib_utils/LibUtilsGroupByDomain.v
similarity index 100%
rename from coq/lib_utils/LibUtilsGroupByDomain.v
rename to rocq/lib_utils/LibUtilsGroupByDomain.v
diff --git a/coq/lib_utils/LibUtilsInterleaved.v b/rocq/lib_utils/LibUtilsInterleaved.v
similarity index 100%
rename from coq/lib_utils/LibUtilsInterleaved.v
rename to rocq/lib_utils/LibUtilsInterleaved.v
diff --git a/coq/lib_utils/LibUtilsLattice.v b/rocq/lib_utils/LibUtilsLattice.v
similarity index 98%
rename from coq/lib_utils/LibUtilsLattice.v
rename to rocq/lib_utils/LibUtilsLattice.v
index 747300b4..902a9440 100644
--- a/coq/lib_utils/LibUtilsLattice.v
+++ b/rocq/lib_utils/LibUtilsLattice.v
@@ -54,8 +54,8 @@ Section Lattice.
meet : A -> A -> A;
join : A -> A -> A;
- meet_morphism :> Proper (eqA ==> eqA ==> eqA) meet ;
- join_morphism :> Proper (eqA ==> eqA ==> eqA) join ;
+ meet_morphism :: Proper (eqA ==> eqA ==> eqA) meet ;
+ join_morphism :: Proper (eqA ==> eqA ==> eqA) join ;
meet_commutative : commutative eqA meet;
meet_associative : associative eqA meet;
diff --git a/coq/lib_utils/LibUtilsLift.v b/rocq/lib_utils/LibUtilsLift.v
similarity index 100%
rename from coq/lib_utils/LibUtilsLift.v
rename to rocq/lib_utils/LibUtilsLift.v
diff --git a/coq/lib_utils/LibUtilsLiftIterators.v b/rocq/lib_utils/LibUtilsLiftIterators.v
similarity index 99%
rename from coq/lib_utils/LibUtilsLiftIterators.v
rename to rocq/lib_utils/LibUtilsLiftIterators.v
index 4446062d..d5043bce 100644
--- a/coq/lib_utils/LibUtilsLiftIterators.v
+++ b/rocq/lib_utils/LibUtilsLiftIterators.v
@@ -17,8 +17,6 @@
(** This module provides support for monadic iterators over bags. *)
Require Import Arith.
-Require Import Min.
-Require Import Max.
Require Import Lia.
Require Import Permutation.
Require Import Equivalence.
diff --git a/coq/lib_utils/LibUtilsListAdd.v b/rocq/lib_utils/LibUtilsListAdd.v
similarity index 100%
rename from coq/lib_utils/LibUtilsListAdd.v
rename to rocq/lib_utils/LibUtilsListAdd.v
diff --git a/coq/lib_utils/LibUtilsResult.v b/rocq/lib_utils/LibUtilsResult.v
similarity index 100%
rename from coq/lib_utils/LibUtilsResult.v
rename to rocq/lib_utils/LibUtilsResult.v
diff --git a/coq/lib_utils/LibUtilsSortingAdd.v b/rocq/lib_utils/LibUtilsSortingAdd.v
similarity index 100%
rename from coq/lib_utils/LibUtilsSortingAdd.v
rename to rocq/lib_utils/LibUtilsSortingAdd.v
diff --git a/coq/lib_utils/LibUtilsStringAdd.v b/rocq/lib_utils/LibUtilsStringAdd.v
similarity index 99%
rename from coq/lib_utils/LibUtilsStringAdd.v
rename to rocq/lib_utils/LibUtilsStringAdd.v
index 67d43375..3ea25391 100644
--- a/coq/lib_utils/LibUtilsStringAdd.v
+++ b/rocq/lib_utils/LibUtilsStringAdd.v
@@ -22,7 +22,6 @@ Require Import Ascii.
Require Import List.
Require Import String.
Require Import Arith.
-Require Import Min.
Require Import Equivalence.
Require Import EquivDec.
Require Import Compare_dec.
@@ -379,7 +378,7 @@ Section Prefix.
rewrite <- (substring_all l) at 3.
apply substring_le_prefix.
replace (String.length l - 0) with (String.length l) by lia.
- apply le_min_r.
+ apply Nat.le_min_r.
Qed.
Lemma in_of_append pre y l :
diff --git a/coq/lib_utils/LibUtilsSublist.v b/rocq/lib_utils/LibUtilsSublist.v
similarity index 100%
rename from coq/lib_utils/LibUtilsSublist.v
rename to rocq/lib_utils/LibUtilsSublist.v
diff --git a/coq/lib_utils/README.md b/rocq/lib_utils/README.md
similarity index 100%
rename from coq/lib_utils/README.md
rename to rocq/lib_utils/README.md
diff --git a/coq/utils/Assoc.v b/rocq/utils/Assoc.v
similarity index 100%
rename from coq/utils/Assoc.v
rename to rocq/utils/Assoc.v
diff --git a/coq/utils/BasicUtils.v b/rocq/utils/BasicUtils.v
similarity index 100%
rename from coq/utils/BasicUtils.v
rename to rocq/utils/BasicUtils.v
diff --git a/coq/utils/ClassicUtils.v b/rocq/utils/ClassicUtils.v
similarity index 100%
rename from coq/utils/ClassicUtils.v
rename to rocq/utils/ClassicUtils.v
diff --git a/coq/utils/CoquelicotAdd.v b/rocq/utils/CoquelicotAdd.v
similarity index 100%
rename from coq/utils/CoquelicotAdd.v
rename to rocq/utils/CoquelicotAdd.v
diff --git a/coq/utils/DVector.v b/rocq/utils/DVector.v
similarity index 98%
rename from coq/utils/DVector.v
rename to rocq/utils/DVector.v
index 1c98f433..628d0edf 100644
--- a/coq/utils/DVector.v
+++ b/rocq/utils/DVector.v
@@ -211,7 +211,7 @@ Proof.
apply In_nth_error in inn.
destruct inn as [i eqq].
destruct x; simpl in *.
- destruct (lt_dec i (length x)).
+ destruct (Compare_dec.lt_dec i (length x)).
- subst.
exists i, l.
unfold vector_nth, proj1_sig.
@@ -360,7 +360,7 @@ Lemma vector_list_create_shiftS
(start len:nat)
(f:(forall m, S start <= m -> m < S start + len -> T)%nat) :
vector_list_create (S start) len f =
- vector_list_create start len (fun x pf1 pf2 => f (S x)%nat (le_n_S _ _ pf1) (lt_n_S _ _ pf2)).
+ vector_list_create start len (fun x pf1 pf2 => f (S x)%nat (le_n_S _ _ pf1) (proj1 (Nat.succ_lt_mono _ _) pf2)).
Proof.
revert start f.
induction len; simpl; trivial; intros.
@@ -376,7 +376,7 @@ Lemma vector_list_create_shift0
(start len:nat)
(f:(forall m, start <= m -> m < start + len -> T)%nat) :
vector_list_create start len f =
- vector_list_create 0 len (fun x _ pf2 => f (start+x)%nat (le_plus_l start _) (plus_lt_compat_l _ _ start pf2)).
+ vector_list_create 0 len (fun x _ pf2 => f (start+x)%nat (Nat.le_add_r start _) (proj1 (Nat.add_lt_mono_l _ _ start) pf2)).
Proof.
induction start; simpl.
- apply vector_list_create_ext; intros.
@@ -448,7 +448,7 @@ Lemma vector_nth_create
(i : nat)
(pf2: i < len)
(f:(forall m, start <= m -> m < start + len -> T)%nat) :
- vector_nth i pf2 (vector_create start len f) = f (start + i) (PeanoNat.Nat.le_add_r start i) (Plus.plus_lt_compat_l _ _ start pf2).
+ vector_nth i pf2 (vector_create start len f) = f (start + i) (PeanoNat.Nat.le_add_r start i) (proj1 (Nat.add_lt_mono_l _ _ start) pf2).
Proof.
unfold vector_nth, proj1_sig.
repeat match_destr.
@@ -507,7 +507,7 @@ Program Definition vector_zip {A B:Type}
Next Obligation.
rewrite combine_length.
repeat rewrite vector_length.
- now rewrite Min.min_idempotent.
+ now rewrite Nat.min_idempotent.
Qed.
(* move this *)
@@ -907,7 +907,7 @@ Qed.
Lemma vector_nthS {A} a i (l:list A) pf1 pf2 :
(vector_nth (S i) pf1
(exist (fun l0 : list A => length l0 = S (length l)) (a :: l) pf2))
- = vector_nth i (lt_S_n _ _ pf1) (exist (fun l0 : list A => length l0 = length l) (l) (eq_add_S _ _ pf2)).
+ = vector_nth i (proj2 (Nat.succ_lt_mono _ _) pf1) (exist (fun l0 : list A => length l0 = length l) (l) (eq_add_S _ _ pf2)).
Proof.
unfold vector_nth, proj1_sig.
repeat match_destr.
@@ -1186,7 +1186,7 @@ Section ivector.
| 0%nat => fun pf _ => False_rect _ (Nat.nlt_0_r _ pf)
| S n' => match idx with
| 0%nat => fun pf '(hd,tl) => hd
- | S m' => fun pf '(hd,tl) => ivector_nth m' (lt_S_n _ _ pf) tl
+ | S m' => fun pf '(hd,tl) => ivector_nth m' (proj2 (Nat.succ_lt_mono _ _) pf) tl
end
end.
@@ -1456,7 +1456,7 @@ Section ivector.
+ specialize (H 0 ((Nat.lt_0_succ _))).
apply H.
+ apply IHn; intros.
- specialize (H (S i1) ( (lt_n_S _ _ pf))).
+ specialize (H (S i1) ( (proj1 (Nat.succ_lt_mono _ _) pf))).
simpl in H.
rewrite (ivector_nth_prf_irrelevance _ i _ _ pf) in H.
now rewrite (ivector_nth_prf_irrelevance _ i0 _ _ pf) in H.
@@ -1488,7 +1488,7 @@ Qed.
:= match n as n' return (forall i (pf:(i ivector B n' with
| 0%nat => fun _ => tt
| S m =>
- fun f => (f 0%nat (Nat.lt_0_succ _) , ivector_create m (fun i pf => f (S i) (lt_n_S _ _ pf)))
+ fun f => (f 0%nat (Nat.lt_0_succ _) , ivector_create m (fun i pf => f (S i) (proj1 (Nat.succ_lt_mono _ _) pf)))
end.
Lemma ivector_nth_from_list {A}
@@ -1576,7 +1576,7 @@ Qed.
Lemma vector_nth_cons_S {A} {n} a (vec : vector A n) i pf :
vector_nth (S i) pf (vector_cons a vec) =
- vector_nth i (lt_S_n i n pf) vec.
+ vector_nth i (proj2 (Nat.succ_lt_mono i n) pf) vec.
Proof.
unfold vector_cons, vector_nth, proj1_sig.
destruct vec.
@@ -1749,7 +1749,7 @@ Section Sequence.
Definition ivector_to_sequence {T} {n} (v : ivector T n) (default : T)
: nat -> T :=
(fun i =>
- match lt_dec i n with
+ match Compare_dec.lt_dec i n with
| left pf => ivector_nth i pf v
| right _ => default
end).
diff --git a/coq/utils/ELim_Seq.v b/rocq/utils/ELim_Seq.v
similarity index 99%
rename from coq/utils/ELim_Seq.v
rename to rocq/utils/ELim_Seq.v
index b18226be..6481ae45 100644
--- a/coq/utils/ELim_Seq.v
+++ b/rocq/utils/ELim_Seq.v
@@ -172,7 +172,7 @@ Proof.
- intros.
exists 0%nat; intros; lia.
- intros HH.
- specialize (IHN (fun x pf => P x (lt_S _ _ pf))).
+ specialize (IHN (fun x pf => P x (Nat.lt_lt_succ_r _ _ pf))).
cut_to IHN.
+ simpl in IHN.
specialize (HH N (Nat.lt_succ_diag_r _)).
diff --git a/coq/utils/FiniteType.v b/rocq/utils/FiniteType.v
similarity index 97%
rename from coq/utils/FiniteType.v
rename to rocq/utils/FiniteType.v
index 51f57628..4ae619a1 100644
--- a/coq/utils/FiniteType.v
+++ b/rocq/utils/FiniteType.v
@@ -1,5 +1,5 @@
Require Import ClassicalDescription.
-Require Import FinFun List EquivDec FunctionalExtensionality Lia Eqdep_dec.
+Require Import Arith FinFun List EquivDec FunctionalExtensionality Lia Eqdep_dec.
Require Import LibUtils.
Require Import Isomorphism ListAdd BasicUtils.
@@ -202,13 +202,12 @@ Lemma fold_right_add_const {A:Type} c (l:list A) :
Proof.
induction l; simpl; trivial.
rewrite IHl; simpl.
- rewrite NPeano.Nat.mul_succ_r.
- now rewrite NPeano.Nat.add_comm.
+ lia.
Qed.
Lemma fold_right_mult_const {A:Type} c (l:list A) :
fold_right Nat.mul 1
- (map (fun _ => c) l) = NPeano.Nat.pow c (length l).
+ (map (fun _ => c) l) = Nat.pow c (length l).
Proof.
induction l; simpl; trivial.
now rewrite IHl; simpl.
@@ -232,12 +231,12 @@ Proof.
(FiniteType_fun_dep_elems_aux fin_elms0 (fun (x0 : A) (_ : In_strong x0 fin_elms0) => fin_elms))))
by (intros; now rewrite map_length).
rewrite fold_right_add_const.
- rewrite NPeano.Nat.mul_comm.
+ rewrite Nat.mul_comm.
now rewrite <- IHfin_elms0.
Qed.
Lemma FiniteType_fun_size {A:Type} {dec:EqDec A eq} {B:Type} (finA:FiniteType A) (finB:FiniteType B)
- : length (@fin_elms _ (FiniteType_fun finA finB)) = NPeano.pow (length (@fin_elms _ finB)) (length (@fin_elms _ finA)).
+ : length (@fin_elms _ (FiniteType_fun finA finB)) = Nat.pow (length (@fin_elms _ finB)) (length (@fin_elms _ finA)).
Proof.
unfold FiniteType_fun.
rewrite FiniteType_fun_dep_size.
@@ -507,7 +506,7 @@ Qed.
intros index eqq.
match_destr_in eqq.
- invcs eqq.
- rewrite PeanoNat.Nat.sub_diag; simpl.
+ rewrite Nat.sub_diag; simpl.
congruence.
- specialize (IHl _ eqq).
generalize (find_index_aux_bounds eqq); intros le1.
@@ -523,7 +522,7 @@ Qed.
Proof.
intros HH.
specialize (find_index_aux_correct HH).
- now rewrite PeanoNat.Nat.sub_0_r.
+ now rewrite Nat.sub_0_r.
Qed.
Lemma find_index_aux_first {l:list A} {a} {index:nat} {n} :
@@ -537,7 +536,7 @@ Qed.
intros index eqq.
match_destr_in eqq.
- invcs eqq.
- rewrite PeanoNat.Nat.sub_diag; simpl.
+ rewrite Nat.sub_diag; simpl.
lia.
- specialize (IHl _ eqq).
intros [|]; simpl; intros eqq2.
@@ -768,10 +767,11 @@ Section countableType.
(find_index' x l = length l)%nat.
Proof.
generalize (find_index'_in x l); intros.
- destruct (Lt.le_lt_or_eq _ _ (find_index'_le x l))
- ; firstorder.
- - lia.
- - intros inn.
+ destruct (proj1 (Nat.lt_eq_cases (find_index' x l) (length l))).
+ - apply find_index'_le.
+ - firstorder; lia.
+ - firstorder.
+ intros inn.
apply H in inn.
lia.
Qed.
diff --git a/coq/utils/FiniteTypeVector.v b/rocq/utils/FiniteTypeVector.v
similarity index 100%
rename from coq/utils/FiniteTypeVector.v
rename to rocq/utils/FiniteTypeVector.v
diff --git a/coq/utils/Isomorphism.v b/rocq/utils/Isomorphism.v
similarity index 100%
rename from coq/utils/Isomorphism.v
rename to rocq/utils/Isomorphism.v
diff --git a/coq/utils/ListAdd.v b/rocq/utils/ListAdd.v
similarity index 99%
rename from coq/utils/ListAdd.v
rename to rocq/utils/ListAdd.v
index 50ccebea..7cac708b 100644
--- a/coq/utils/ListAdd.v
+++ b/rocq/utils/ListAdd.v
@@ -398,15 +398,15 @@ Section fp.
rewrite Forall_forall in H.
apply H.
apply nth_In.
- now apply lt_S_n.
+ auto with arith.
+ right.
rewrite Forall_forall in H.
symmetry.
apply H.
apply nth_In.
- now apply lt_S_n.
- + apply lt_S_n in n1bound.
- apply lt_S_n in n2bound.
+ auto with arith.
+ + apply Nat.succ_lt_mono in n1bound.
+ apply Nat.succ_lt_mono in n2bound.
destruct (IHFP _ _ n1bound n2bound) as [?|?]; auto.
Qed.
@@ -2036,7 +2036,7 @@ Section cross_product.
apply Forall_impl; intros.
lia.
+
- rewrite <- Plus.plus_Snm_nSm.
+ replace (n + S (length l0)) with (S (n + length l0)) by auto with arith.
specialize (IHl0 (map (fun '(a0, b0) => b0 ++ [a0]) (list_prod a acc)) (S n)).
apply IHl0.
rewrite Forall_map; simpl.
@@ -2212,12 +2212,12 @@ Qed.
Lemma list_max_in l : l <> nil -> In (list_max l) l.
Proof.
induction l; simpl; [eauto |]; intros _.
- destruct (Max.max_dec a (list_max l))
+ destruct (Nat.max_dec a (list_max l))
; rewrite e in *
; eauto.
destruct l.
- simpl in e.
- rewrite Max.max_0_r in e; simpl
+ rewrite Nat.max_0_r in e; simpl
; eauto.
- right; apply IHl; congruence.
Qed.
diff --git a/coq/utils/NumberIso.v b/rocq/utils/NumberIso.v
similarity index 99%
rename from coq/utils/NumberIso.v
rename to rocq/utils/NumberIso.v
index caad723b..e56c1d41 100644
--- a/coq/utils/NumberIso.v
+++ b/rocq/utils/NumberIso.v
@@ -1,4 +1,4 @@
-Require Import BinNums BinNat Nat List.
+Require Import ZArith BinNums BinNat Nat List.
Require Import Lia.
Require Import LibUtils Isomorphism PairEncoding.
diff --git a/coq/utils/PairEncoding.v b/rocq/utils/PairEncoding.v
similarity index 99%
rename from coq/utils/PairEncoding.v
rename to rocq/utils/PairEncoding.v
index f6fcf493..4acab6ea 100644
--- a/coq/utils/PairEncoding.v
+++ b/rocq/utils/PairEncoding.v
@@ -1,4 +1,4 @@
-Require Import BinNums Nat List BinInt.
+Require Import Arith BinNums Nat List BinInt.
Require Import Lia.
Require Import LibUtils Isomorphism ListAdd.
@@ -1255,7 +1255,7 @@ Global Instance nat_pair_encoder : Isomorphism (nat*nat) nat := pair_encoder nat
now subst.
- destruct (Compare_dec.le_dec c1 c).
+ specialize (IHc l).
- eapply Le.le_trans.
+ eapply Nat.le_trans.
apply IHc.
rewrite seq_S with (len := (S c)).
rewrite map_app, list_max_app.
diff --git a/coq/utils/PushNeg.v b/rocq/utils/PushNeg.v
similarity index 100%
rename from coq/utils/PushNeg.v
rename to rocq/utils/PushNeg.v
diff --git a/coq/utils/Quotient.v b/rocq/utils/Quotient.v
similarity index 100%
rename from coq/utils/Quotient.v
rename to rocq/utils/Quotient.v
diff --git a/coq/utils/RbarAdd.v b/rocq/utils/RbarAdd.v
similarity index 97%
rename from coq/utils/RbarAdd.v
rename to rocq/utils/RbarAdd.v
index 201baa26..5c33c7b2 100644
--- a/coq/utils/RbarAdd.v
+++ b/rocq/utils/RbarAdd.v
@@ -218,7 +218,7 @@ Section sums.
unfold sum_Rbar_n; intros.
induction n; [simpl; f_equal; lra | ].
rewrite seq_Sn.
- rewrite plus_0_l.
+ rewrite Nat.add_0_l.
repeat rewrite map_app.
repeat rewrite list_Rbar_sum_nneg_plus; simpl
@@ -895,8 +895,8 @@ Qed.
+ now rewrite iso_b_f.
+ apply in_seq.
split; [lia |].
- rewrite plus_0_l.
- apply le_lt_n_Sm.
+ rewrite Nat.add_0_l.
+ apply Nat.lt_succ_r.
destruct H1.
apply in_seq in H1.
apply in_seq in H2.
@@ -1616,7 +1616,7 @@ Proof.
split.
+ now apply classic_min_of_some in H0.
+ intros.
- apply NPeano.Nat.nlt_ge; intros nlt.
+ apply Nat.nlt_ge; intros nlt.
eapply classic_min_of_some_first in H0; try apply nlt.
tauto.
- intros.
@@ -1624,7 +1624,7 @@ Proof.
apply zerotails in H.
apply is_lim_seq_spec in H.
simpl in H.
- elimtype False.
+ exfalso.
destruct (H ϵ) as [N ?].
elim (H1 N).
red; intros.
@@ -1752,10 +1752,10 @@ Proof.
Qed.
Definition zerotails_incr_mult (a : nat -> R) (pf:ex_series a) n : R
- := Series (fun n0 : nat => if le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0).
+ := Series (fun n0 : nat => if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0).
Lemma zerotails_incr_mult_ex (a : nat -> R) (pf:ex_series a) n :
- ex_series (fun n0 : nat => if le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0).
+ ex_series (fun n0 : nat => if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0).
Proof.
apply (ex_series_incr_n _ n).
apply (ex_series_ext (fun _ => 0)).
@@ -1769,7 +1769,7 @@ Proof.
Qed.
Lemma zerotails_incr_mult_trunc (a : nat -> R) (pf:ex_series a) n :
- zerotails_incr_mult a pf n = sum_n (fun n0 : nat => if le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0) n.
+ zerotails_incr_mult a pf n = sum_n (fun n0 : nat => if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf n0)) n then 1 else 0) n.
Proof.
unfold zerotails_incr_mult.
apply is_series_unique.
@@ -1777,11 +1777,12 @@ Proof.
apply is_lim_seq_spec.
simpl; intros.
exists n; intros.
- generalize (sum_n_m_sum_n (fun n1 : nat => if le_dec (S (zerotails_eps2k_fun a pf n1)) n then 1 else 0) n n0).
+ generalize (sum_n_m_sum_n (fun n1 : nat => if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf n1)) n then 1 else 0) n n0 H).
match goal with
[|- context [minus ?x ?y]] => replace (minus x y) with (x - y) by reflexivity
end; simpl.
- intros HH; rewrite <- HH; trivial.
+ intros HH.
+ eapply Rle_lt_trans; [right; apply (symmetry (f_equal Rabs HH)) |].
erewrite (sum_n_m_ext_loc _ (fun _ => zero)).
- rewrite sum_n_m_const_zero.
unfold zero; simpl.
@@ -1826,7 +1827,7 @@ Proof.
assert (0 <= sum_f_R0'
(fun x : nat =>
if
- le_dec (S (zerotails_eps2k_fun a pf (S (x + S n))))
+ Compare_dec.le_dec (S (zerotails_eps2k_fun a pf (S (x + S n))))
(n + S (zerotails_eps2k_fun a pf m))
then 1
else 0) (zerotails_eps2k_fun a pf m)
@@ -1941,13 +1942,13 @@ Proof.
transitivity (
Series (fun k : nat => Series (fun n : nat =>
a n *
- if le_dec (S (zerotails_eps2k_fun a pf k)) n
+ if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf k)) n
then 1 else 0))).
{
apply Series_ext; intros.
rewrite (Series_incr_n_aux
(fun n0 : nat =>
- a n0 * (if le_dec (S (zerotails_eps2k_fun a pf n))%nat n0 then 1 else 0))
+ a n0 * (if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf n))%nat n0 then 1 else 0))
(S (zerotails_eps2k_fun a pf n))).
- apply Series_ext; intros.
match_destr.
@@ -1965,7 +1966,7 @@ Proof.
(fun n : nat =>
Series
(fun k : nat =>
- a n * (if le_dec (S (zerotails_eps2k_fun a pf k)) n then 1 else 0)))).
+ a n * (if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf k)) n then 1 else 0)))).
{
apply Series_nneg_nested_swap.
- intros.
@@ -1983,7 +1984,7 @@ Proof.
unfold pointwise_relation; simpl; intros.
rewrite <- ELim_seq_incr_1.
rewrite ELim_seq_ext with
- (v := (fun n => Finite (sum_n (fun j : nat => a j * (if le_dec (S (zerotails_eps2k_fun a pf a0)) j then 1 else 0)) n))).
+ (v := (fun n => Finite (sum_n (fun j : nat => a j * (if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf a0)) j then 1 else 0)) n))).
rewrite Elim_seq_fin.
rewrite <- ELim_seq_incr_1.
rewrite ELim_seq_ext with
@@ -1993,7 +1994,7 @@ Proof.
rewrite ex_series_Lim_seq.
-- rewrite (Series_incr_n_aux
(fun n0 : nat =>
- a n0 * (if le_dec (S (zerotails_eps2k_fun a pf a0))%nat n0 then 1 else 0))
+ a n0 * (if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf a0))%nat n0 then 1 else 0))
(S (zerotails_eps2k_fun a pf a0))).
++ apply Rbar_finite_eq.
apply Series_ext; intros.
@@ -2011,7 +2012,7 @@ Proof.
(fun x => a ((S (zerotails_eps2k_fun a pf a0)) + x)%nat)).
++ intros; f_equal; lia.
++ now apply ex_series_incr_n.
- -- apply (ex_series_le (fun j : nat => a j * (if le_dec (S (zerotails_eps2k_fun a pf a0)) j then 1 else 0)) a); trivial.
+ -- apply (ex_series_le (fun j : nat => a j * (if Compare_dec.le_dec (S (zerotails_eps2k_fun a pf a0)) j then 1 else 0)) a); trivial.
intros.
unfold norm; simpl.
unfold abs; simpl.
@@ -2068,7 +2069,7 @@ Qed.
- intros.
simpl.
destruct H2.
- destruct (le_dec x n).
+ destruct (Compare_dec.le_dec x n).
+ now apply H2.
+ specialize (H2 x).
cut_to H2; try lia; simpl in H2.
@@ -2085,7 +2086,7 @@ Qed.
split; try lia.
assert (Rbar_le (f x) (f (Init.Nat.max x N))).
{
- destruct (le_dec N x).
+ destruct (Compare_dec.le_dec N x).
- rewrite Nat.max_l; try lia.
apply Rbar_le_refl.
- rewrite Nat.max_r; try lia.
@@ -2116,7 +2117,7 @@ Qed.
split; try lia.
assert (Rbar_le (f x) (f (Init.Nat.max x N))).
{
- destruct (le_dec N x).
+ destruct (Compare_dec.le_dec N x).
- rewrite Nat.max_l; try lia.
apply Rbar_le_refl.
- rewrite Nat.max_r; try lia.
@@ -2129,7 +2130,7 @@ Qed.
- unfold is_ELimSup_seq, is_sup_seq.
split; intros.
+ destruct (H0 M).
- destruct (le_dec x n).
+ destruct (Compare_dec.le_dec x n).
* now apply H1.
* assert (n <= x)%nat by lia.
apply Rbar_le_lt_trans with (y := f x).
diff --git a/coq/utils/RealAdd.v b/rocq/utils/RealAdd.v
similarity index 99%
rename from coq/utils/RealAdd.v
rename to rocq/utils/RealAdd.v
index 4ee9811e..9c237488 100644
--- a/coq/utils/RealAdd.v
+++ b/rocq/utils/RealAdd.v
@@ -6,7 +6,7 @@ Require Import Coq.Reals.Rfunctions.
Require Import Coq.Reals.Rprod Coq.Reals.ROrderedType.
Require Import Ranalysis_reg.
Require Import Coquelicot.Coquelicot.
-Require Import Lia Lra.
+Require Import ZArith Lia Lra.
Require Import Reals.Integration.
Require Import Coq.Reals.SeqProp.
Require Import Rtrigo_def.
@@ -1615,7 +1615,7 @@ Proof.
unfold Rpower.
unfold ln.
match_destr; try lra.
- - elimtype False; try tauto; lra.
+ - exfalso; try tauto; lra.
- rewrite Rmult_0_r.
now rewrite exp_0.
Qed.
@@ -1641,7 +1641,7 @@ Lemma Rpower_base_0 e : Rpower 0 e = 1.
Proof.
unfold Rpower, ln.
match_destr.
- - elimtype False; try tauto; lra.
+ - exfalso; try tauto; lra.
- rewrite Rmult_0_r.
now rewrite exp_0.
Qed.
@@ -2762,7 +2762,7 @@ Section lim_seq_sup_seq.
simpl.
split; intros.
+ destruct H1.
- destruct (le_dec x n).
+ destruct (Compare_dec.le_dec x n).
* now apply H1.
* assert (n <= x)%nat by lia.
apply Rle_lt_trans with (r2 := f x).
@@ -2782,7 +2782,7 @@ Section lim_seq_sup_seq.
unfold is_sup_seq; simpl; intros.
specialize (H0 M).
destruct H0 as [N H0].
- destruct (le_dec N n).
+ destruct (Compare_dec.le_dec N n).
+ now apply H0.
+ assert (n <= N)%nat by lia.
apply Rle_lt_trans with (r2 := f N).
@@ -3371,7 +3371,7 @@ Section Rmax_list.
(ex:exists x, In x l /\ P x) : {x | In x l /\ P x}.
Proof.
induction l; simpl.
- - elimtype False.
+ - exfalso.
destruct ex ; intuition.
- destruct (dec a).
+ exists a ; eauto.
@@ -4741,7 +4741,7 @@ Proof.
destruct (Hx posreal_one).
destruct (fin_seq_bounded x x0).
exists (Rmax x1 ((Rabs c)+1)); intros.
- destruct (lt_dec n x0).
+ destruct (Compare_dec.lt_dec n x0).
- eapply Rle_trans.
+ apply H0; lia.
+ apply Rmax_l.
@@ -4905,7 +4905,7 @@ Qed.
Lemma sum_n_sum_f_clipped (f : nat -> R) (N : nat) :
forall (n:nat),
(n >= N)%nat ->
- sum_n f N = sum_n (fun j => if (le_dec j N) then (f j) else 0) n.
+ sum_n f N = sum_n (fun j => if (Compare_dec.le_dec j N) then (f j) else 0) n.
Proof.
intros.
replace (n) with (N + (n - N))%nat by lia.
@@ -6652,7 +6652,7 @@ Section powerRZ.
unfold proj1_sig.
match_destr; match_destr.
destruct (Z_le_gt_dec x x0); trivial.
- elimtype False.
+ exfalso.
assert (x0 <= x - 1)%Z by lia.
assert (powerRZ base x0 <= powerRZ base (x-1)%Z).
{
@@ -6820,7 +6820,7 @@ Qed.
Proof.
induction l; try now simpl.
simpl.
- destruct (lt_dec 0 (length l)).
+ destruct (Compare_dec.lt_dec 0 (length l)).
+ rewrite prod_f_R0_succ; try assumption.
rewrite IHl. rewrite Rmult_assoc.
f_equal. rewrite prod_f_R0_pred; try assumption.
@@ -6894,7 +6894,7 @@ Qed.
prod_f_R0 f n2 = 0.
Proof.
intros.
- destruct (lt_dec n1 n2).
+ destruct (Compare_dec.lt_dec n1 n2).
- rewrite prod_SO_split with (k := n1) (n := n2); trivial.
rewrite prod_f_R0_n; trivial.
apply Rmult_0_l.
diff --git a/coq/utils/RiemannAdd.v b/rocq/utils/RiemannAdd.v
similarity index 99%
rename from coq/utils/RiemannAdd.v
rename to rocq/utils/RiemannAdd.v
index f4c40523..bf88c61a 100644
--- a/coq/utils/RiemannAdd.v
+++ b/rocq/utils/RiemannAdd.v
@@ -1,4 +1,4 @@
-Require Import Reals.Rbase Coq.Reals.RList.
+Require Import ZArith Reals.Rbase Coq.Reals.RList.
Require Import Reals.Rfunctions.
Require Import Ranalysis_reg.
Require Import Reals.Integration.
diff --git a/coq/utils/StreamAdd.v b/rocq/utils/StreamAdd.v
similarity index 98%
rename from coq/utils/StreamAdd.v
rename to rocq/utils/StreamAdd.v
index 68bdd213..a3df94c7 100644
--- a/coq/utils/StreamAdd.v
+++ b/rocq/utils/StreamAdd.v
@@ -135,7 +135,7 @@ Section Cutting.
List.firstn i (firstn j s) = List.firstn j (firstn i s).
Proof.
repeat rewrite firstn_firstn.
- rewrite Min.min_comm.
+ rewrite PeanoNat.Nat.min_comm.
trivial.
Qed.
diff --git a/coq/utils/StreamLimits.v b/rocq/utils/StreamLimits.v
similarity index 100%
rename from coq/utils/StreamLimits.v
rename to rocq/utils/StreamLimits.v
diff --git a/coq/utils/Sums.v b/rocq/utils/Sums.v
similarity index 98%
rename from coq/utils/Sums.v
rename to rocq/utils/Sums.v
index d50af400..0b2b3e70 100644
--- a/coq/utils/Sums.v
+++ b/rocq/utils/Sums.v
@@ -558,9 +558,9 @@ Section inf_sum'.
intros n ngt.
specialize (H1 n).
- cut_to H1; [ | apply (le_trans _ (max N1 N2)); auto with arith].
+ cut_to H1; [ | apply (Nat.le_trans _ (max N1 N2)); auto with arith].
specialize (H2 n).
- cut_to H2; [ | apply (le_trans _ (max N1 N2)); auto with arith].
+ cut_to H2; [ | apply (Nat.le_trans _ (max N1 N2)); auto with arith].
rewrite sum_f_R0'_plus.
generalize (R_dist_plus (sum_f_R0' f1 n) sum1 (sum_f_R0' f2 n) sum2); intros.
@@ -685,8 +685,8 @@ Section harmonic.
intros neq.
induction b; simpl.
- lia.
- - apply gt_n_S in IHb.
- eapply le_gt_trans; try eassumption.
+ - apply Nat.succ_lt_mono in IHb.
+ eapply Nat.lt_le_trans; try eassumption.
apply Sle_mult_gt1; lia.
Qed.
Lemma sum_f_R0'_eq2 n :
@@ -763,10 +763,10 @@ Section harmonic.
lra.
- rewrite Nat.pow_mul_r.
simpl.
- rewrite NPeano.Nat.pow_add_r.
+ rewrite Nat.pow_add_r.
unfold ge.
replace N with (N * 1)%nat at 1 by lia.
- apply mult_le_compat.
+ apply Nat.mul_le_mono.
+ generalize (pow_exp_gt 4 N)
; lia.
+ generalize (Z.to_nat (up l)); intros n.
@@ -1004,7 +1004,7 @@ Proof.
apply in_seq.
split; [lia|].
simpl.
- eapply lt_le_trans; try apply mbig.
+ eapply Nat.lt_le_trans; try apply mbig.
unfold M.
unfold lt.
generalize (list_max_upper (map ginv (seq 0 N))); intros FF.
@@ -1034,7 +1034,7 @@ Proof.
assert(l2_lower: (forall x, In x l2 -> x >= N)%nat).
{
intros.
- destruct (ge_dec x N); trivial.
+ destruct (Compare_dec.ge_dec x N); trivial.
apply Compare_dec.not_ge in n.
assert (inn:In x gpre).
{
@@ -1047,7 +1047,7 @@ Proof.
}
pose (nn:=List.list_max l2).
destruct (list_max_le l2 nn) as [l2_upper _].
- specialize (l2_upper (le_refl _)).
+ specialize (l2_upper (Nat.le_refl _)).
assert (incl1:incl l2 (seq N (S nn-N))).
{
intros ? inn.
@@ -1099,7 +1099,7 @@ Proof.
rewrite map_map.
specialize (N2_lt nn (max N1 N2))%nat.
- cut_to N2_lt.
+ cut_to N2_lt; try lia.
- unfold R_dist in N2_lt.
repeat rewrite sum_f_R0_sum_f_R0' in N2_lt.
repeat rewrite sum_f_R0'_list_sum in N2_lt.
@@ -1115,11 +1115,7 @@ Proof.
apply in_map_iff in H.
destruct H as [?[??]]; subst.
apply Rabs_pos.
- + apply le_plus_minus_r.
- lia.
- - red.
- transitivity N; lia.
- - lia.
+ + lia.
Qed.
Corollary infinite_sum'_pos_perm (g:nat->nat) (f:nat -> R) l:
@@ -1247,7 +1243,7 @@ Qed.
Lemma sum_n_m_shift (α : nat -> R) (k n0 : nat) :
- sum_n_m α k (n0 + k)%nat = sum_n (fun n1 : nat => α (n1 + k)%nat) n0.
+ @sum_n_m R_AbelianGroup α k (n0 + k)%nat = @sum_n R_AbelianGroup (fun n1 : nat => α (n1 + k)%nat) n0.
Proof.
unfold sum_n.
induction n0.
@@ -1263,7 +1259,7 @@ Qed.
Lemma sum_n_m_pos a n1 n2 :
(forall n, (n1 <= n <= n2)%nat -> 0 <= a n) ->
- 0 <= (sum_n_m a n1 n2).
+ 0 <= (@sum_n_m R_AbelianGroup a n1 n2).
Proof.
intros.
rewrite sum_n_m_fold_right_seq.
@@ -1280,7 +1276,7 @@ Qed.
Lemma sum_n_pos_incr a n1 n2 : (forall n, (n1 < n <= n2)%nat -> 0 <= a n) ->
- (n1 <= n2)%nat -> sum_n a n1 <= sum_n a n2.
+ (n1 <= n2)%nat -> @sum_n R_AbelianGroup a n1 <= @sum_n R_AbelianGroup a n2.
Proof.
intros.
destruct (Nat.eq_dec n1 n2); [rewrite e; lra|].
@@ -1349,7 +1345,7 @@ Section Sequences.
destruct H.
exists (Rmax ((Rabs x)+1) (Rmax_list (map (fun n => Rabs (f n)) (seq 0 x0)))).
intros.
- destruct (dec_le x0 n).
+ destruct (Compare_dec.dec_le x0 n).
- specialize (H n H1).
apply Rle_trans with (r2 := Rabs x + 1).
+ simpl in H.
@@ -1862,7 +1858,7 @@ Qed.
rewrite <-H2. simpl.
do 2 rewrite <-sum_n_Reals.
replace n with (S (pred n)) by lia.
- rewrite sum_n_m_sum_n; try lia.
+ rewrite (@sum_n_m_sum_n R_AbelianGroup); try lia.
reflexivity.
Qed.
@@ -1882,7 +1878,7 @@ Qed.
rewrite <-H2. simpl.
do 2 rewrite <-sum_n_Reals.
replace n with (S (pred n)) by lia.
- rewrite sum_n_m_sum_n; try lia.
+ rewrite (@sum_n_m_sum_n R_AbelianGroup); try lia.
reflexivity.
Qed.
@@ -1894,7 +1890,7 @@ Qed.
Series (fun i => f (S m + i)%nat).
Proof.
intros.
- destruct (lt_dec 0 n).
+ destruct (Compare_dec.lt_dec 0 n).
- apply sum_n_m_Series1; trivial; lia.
- assert (n=0)%nat by lia.
setoid_rewrite H1.
@@ -1929,7 +1925,7 @@ Qed.
apply Rmax_list_map_seq_lt_gen; try lia.
intros.
specialize (H0 (N+k)%nat n).
- destruct (le_dec (N + k)%nat n).
+ destruct (Compare_dec.le_dec (N + k)%nat n).
- apply H0; lia.
- assert (n < N + k)%nat by lia.
rewrite sum_n_m_zero; try lia.
@@ -1957,7 +1953,7 @@ Qed.
apply Rmax_list_map_seq_lt_gen; try lia.
intros.
specialize (H0 (S (N + k)) (n - 1)%nat).
- destruct (lt_dec (n-1)%nat (S (N + k))).
+ destruct (Compare_dec.lt_dec (n-1)%nat (S (N + k))).
- rewrite sum_n_m_zero; try lia.
unfold zero; simpl.
cbn.
@@ -2306,7 +2302,7 @@ Section tails.
{
intros.
unfold s.
- destruct (lt_dec n m).
+ destruct (Compare_dec.lt_dec n m).
- unfold sum_n.
apply Rge_le.
rewrite (sum_n_m_Chasles _ _ n); try lia.
@@ -2339,7 +2335,7 @@ Section tails.
- unfold Rdiv.
rewrite sum_n_m_ext with (b := (fun n : nat => scal (/ s (S N + k)%nat) (gamma n))).
+ rewrite sum_n_m_scal_l.
- rewrite sum_n_m_sum_n; try lia.
+ rewrite (@sum_n_m_sum_n R_AbelianGroup); try lia.
unfold s.
unfold scal; simpl.
unfold mult; simpl.
@@ -2387,7 +2383,8 @@ Section tails.
rewrite sum_n_m_ext with (b := fun n => / s n * gamma n) by (intros; unfold Rdiv; now rewrite Rmult_comm).
generalize (sum_n_m_sum_n (fun n : nat => / s n * gamma n) N (S N + k)); intros.
cut_to H5; try lia.
- rewrite H5.
+ eapply Rle_lt_trans; [right; apply H5 |].
+
unfold minus; simpl.
unfold plus, opp; simpl.
rewrite Rabs_minus_sym in H3.
@@ -2396,7 +2393,7 @@ Section tails.
unfold minus, plus, opp in H5; simpl in H5.
unfold Rminus.
replace (S N + k)%nat with (S (N + k))%nat by lia.
- rewrite <- H5.
+ etransitivity; [right; symmetry; apply H5 |].
apply Rle_ge.
apply sum_n_m_pos.
intros.
@@ -2457,7 +2454,7 @@ Section tails.
sum_n (fun n0 : nat => X (n0 + S m)%nat) a.
Proof.
rewrite <-sum_n_m_shift.
- rewrite sum_n_m_sum_n; try lia.
+ rewrite (@sum_n_m_sum_n R_AbelianGroup); try lia.
reflexivity.
Qed.
@@ -2517,7 +2514,7 @@ Section tails.
cut_to H2; trivial.
- destruct H2 as [rho [? ?]].
assert (0 < 1) by lra.
- exists (fun n => if (lt_dec n N) then (mkposreal _ H4) else rho (n - N)%nat).
+ exists (fun n => if (Compare_dec.lt_dec n N) then (mkposreal _ H4) else rho (n - N)%nat).
split.
+ apply is_lim_seq_incr_n with (N := N).
apply is_lim_seq_ext with (u := rho); trivial.
diff --git a/coq/utils/Utils.v b/rocq/utils/Utils.v
similarity index 100%
rename from coq/utils/Utils.v
rename to rocq/utils/Utils.v
diff --git a/coq/utils/Vector.v b/rocq/utils/Vector.v
similarity index 96%
rename from coq/utils/Vector.v
rename to rocq/utils/Vector.v
index 39b2027d..03aa9b70 100644
--- a/coq/utils/Vector.v
+++ b/rocq/utils/Vector.v
@@ -1,4 +1,4 @@
-Require Import List Lia.
+Require Import Arith List Lia.
Require Import LibUtils ListAdd BasicUtils.
Section Vector.
@@ -38,7 +38,7 @@ Section Vector.
init
| S bound1 =>
fun pf0 : S bound1 <= m =>
- let an := vector_fold_right1_bounded_dep f init singleton v bound1 (Le.le_Sn_le bound1 m pf0) in
+ let an := vector_fold_right1_bounded_dep f init singleton v bound1 (Nat.lt_le_incl bound1 m pf0) in
match bound1 as bound1' return (A bound1' -> S bound1' <= m -> A (S bound1')) with
| 0 => fun (_ : A 0) (pf1 : 1 <= m) =>
@@ -58,7 +58,7 @@ Section Vector.
- apply f.
+ exact (v (exist _ n pf)).
+ apply IHn.
- exact (Le.le_Sn_le _ _ pf).
+ exact (Nat.lt_le_incl _ _ pf).
Defined.
Definition vnil {T} : Vector T 0.
@@ -74,19 +74,19 @@ Section Vector.
+ exact x.
+ apply v.
exists i.
- apply NPeano.Nat.le_neq.
+ apply Nat.le_neq.
split; trivial.
now apply le_S_n in pf.
Defined.
- Definition vhd {T} {n} (v:Vector T (S n)) : T := v (exist _ (0%nat) (NPeano.Nat.lt_0_succ n)).
- Definition vlast {T} {n} (v:Vector T (S n)) : T := v (exist _ (n%nat) (NPeano.Nat.lt_succ_diag_r n)).
+ Definition vhd {T} {n} (v:Vector T (S n)) : T := v (exist _ (0%nat) (Nat.lt_0_succ n)).
+ Definition vlast {T} {n} (v:Vector T (S n)) : T := v (exist _ (n%nat) (Nat.lt_succ_diag_r n)).
Definition vdrop_last {T} {n} (v:Vector T (S n)) : Vector T n.
Proof.
intros [i pf]; apply v.
exists i.
- apply NPeano.Nat.lt_lt_succ_r; trivial.
+ apply Nat.lt_lt_succ_r; trivial.
Defined.
@@ -120,11 +120,11 @@ Section Vector.
Definition vector_fold_right1_dep {A:nat->Type} {B} (f:forall n, B->A n->A (S n))
(init:A 0%nat) (singleton:B->A 1%nat) {m:nat} (v:Vector B m) : A m
- := vector_fold_right1_bounded_dep f init singleton v m (Le.le_refl _).
+ := vector_fold_right1_bounded_dep f init singleton v m (Nat.le_refl _).
Definition vector_fold_right_dep {A:nat->Type} {B} (f:forall n, B->A n->A (S n))
(init:A 0%nat) {m:nat} (v:Vector B m) : A m
- := vector_fold_right_bounded_dep f init v m (Le.le_refl _).
+ := vector_fold_right_bounded_dep f init v m (Nat.le_refl _).
Definition vector_fold_right1 {A B:Type} (f:B->A->A) (init:A) (singleton:B->A) {m:nat} (v:Vector B m)
:= vector_fold_right1_dep (A:=fun _ => A) (fun _ => f) init singleton v.
@@ -210,7 +210,7 @@ Section Vector.
Definition list_fold_right1_dep {A:nat->Type} {B} (f:forall n, B->A n->A (S n))
(init:A 0%nat) (singleton:B->A 1%nat) (l:list B) : A (length l)
- := list_fold_right1_bounded_dep f init singleton l (length l) (Le.le_refl _).
+ := list_fold_right1_bounded_dep f init singleton l (length l) (Nat.le_refl _).
Definition list_fold_right_dep {A:nat->Type} {B} (f:forall n, B->A n->A (S n))
(init:A 0%nat) (l:list B) : A (length l)
@@ -274,7 +274,7 @@ Section Vector.
Proof.
intros [i pf].
unfold vcons, vlast, vdrop_last.
- destruct (NPeano.Nat.eq_dec i n)
+ destruct (Nat.eq_dec i n)
; subst
; f_equal
; apply index_pf_irrel.
@@ -380,7 +380,7 @@ Section Vector.
intros; subst.
intros [i pf].
unfold vcons.
- destruct (NPeano.Nat.eq_dec i n); simpl; trivial.
+ destruct (Nat.eq_dec i n); simpl; trivial.
Qed.
Lemma vdrop_last_proper {T} {n} (x y:Vector T (S n)) : x =v= y -> vdrop_last x =v= vdrop_last y.
@@ -417,7 +417,7 @@ Section Vector.
rewrite (vector_fold_right_dep_ext _ _ (vector_Sn_split v)).
unfold vector_fold_right_dep.
simpl.
- destruct (NPeano.Nat.eq_dec m m) ; [ | congruence].
+ destruct (Nat.eq_dec m m) ; [ | congruence].
f_equal.
erewrite vector_fold_right_dep_bounded_pf_ext.
erewrite vector_fold_right_dep_bounded_cut_down.
@@ -514,7 +514,7 @@ Section Vector.
Lemma vector_fold_right1_dep_1 {A:nat->Type} {B} (f:forall n,B->A n->A (S n))
(init:A 0%nat) sing (v:Vector B 1) :
- vector_fold_right1_dep f init sing v = sing (v (exist _ 0 NPeano.Nat.lt_0_1)).
+ vector_fold_right1_dep f init sing v = sing (v (exist _ 0 Nat.lt_0_1)).
Proof.
unfold vector_fold_right1_dep.
simpl.
@@ -677,7 +677,7 @@ Section Vector.
- intros [[i pf] eqqi].
unfold vector_to_list in *.
rewrite vector_fold_right_Sn; simpl.
- destruct (NPeano.Nat.eq_dec i n).
+ destruct (Nat.eq_dec i n).
+ left.
unfold vlast.
subst.
@@ -697,7 +697,7 @@ Section Vector.
unfold vcons.
split.
- intros [[i pf] eqq].
- destruct (NPeano.Nat.eq_dec i n).
+ destruct (Nat.eq_dec i n).
+ subst; eauto.
+ right.
eexists (exist _ i _).
@@ -708,10 +708,10 @@ Section Vector.
- intros [eqq | inn].
+ red.
eexists (exist _ n _).
- destruct (NPeano.Nat.eq_dec n n); congruence.
+ destruct (Nat.eq_dec n n); congruence.
+ destruct inn as [[i pf] eqq].
eexists (exist _ i _).
- destruct (NPeano.Nat.eq_dec i n); [lia | ].
+ destruct (Nat.eq_dec i n); [lia | ].
erewrite index_pf_irrel; eauto.
Unshelve.
simpl; lia.
@@ -777,7 +777,7 @@ Section Vector.
- lia.
- rewrite vector_fold_right_dep_Sn.
simpl.
- destruct (NPeano.Nat.eq_dec i n).
+ destruct (Nat.eq_dec i n).
+ subst.
unfold vlast.
erewrite index_pf_irrel; eauto.
@@ -855,7 +855,7 @@ Section Vector.
Qed.
Definition bounded_seq (start len : nat) : list {n':nat | n' < start+len}%nat
- := bounded_seq_bounded start len len (Le.le_refl _).
+ := bounded_seq_bounded start len len (Nat.le_refl _).
Definition bounded_seq0 len : list {n':nat | n' < len}%nat := bounded_seq 0 len.
@@ -918,7 +918,7 @@ Section Vector.
+ rewrite vector_fold_right_Sn.
intros [Plast Pdrop].
intros [i pf].
- destruct (NPeano.Nat.eq_dec i n).
+ destruct (Nat.eq_dec i n).
* unfold vlast in Plast.
subst.
erewrite index_pf_irrel; eauto.
@@ -956,7 +956,7 @@ Section Vector.
specialize (IHn _ _ eqq2).
rewrite eqq1.
unfold vcons.
- destruct (NPeano.Nat.eq_dec i n); trivial.
+ destruct (Nat.eq_dec i n); trivial.
Qed.
Lemma vectoro_to_ovector_forall_some_b {A n} (vo:Vector (option A) n) (v:Vector A n) :
@@ -1049,7 +1049,7 @@ Section Vector.
- intros eqq.
rewrite vector_fold_right_dep_Sn.
unfold vlast.
- destruct (NPeano.Nat.eq_dec i n).
+ destruct (Nat.eq_dec i n).
+ subst.
erewrite index_pf_irrel.
rewrite eqq; simpl; trivial.
@@ -1066,7 +1066,7 @@ Section Vector.
intros [i pf2].
apply v.
exists i.
- eapply NPeano.Nat.lt_le_trans; eassumption.
+ eapply Nat.lt_le_trans; eassumption.
Defined.
Lemma vfirstn0 {T} {n} (v:Vector T n) pf : vfirstn v 0 pf = vnil.
@@ -1127,10 +1127,10 @@ Section Vector.
sing (v (exist (fun n' : nat => (n' < S m)%nat) 0%nat pf1))
| S bound2 =>
fun (an' : A (S bound2)) (_ : (S (S bound2) <= S m)%nat) =>
- f (S bound2) (v (exist (fun n' : nat => (n' < S m)%nat) bound (Le.le_Sn_le (S bound) (S m) pf))) an'
+ f (S bound2) (v (exist (fun n' : nat => (n' < S m)%nat) bound (Nat.lt_le_incl (S bound) (S m) pf))) an'
end
(vector_fold_right1_bounded_dep f init sing v bound
- (Le.le_Sn_le bound (S m) (Le.le_Sn_le (S bound) (S m) pf))) (Le.le_Sn_le (S bound) (S m) pf)) with (vector_fold_right1_bounded_dep f init sing v (S bound) pf2); try eapply IHbound.
+ (Nat.lt_le_incl bound (S m) (Nat.lt_le_incl (S bound) (S m) pf))) (Nat.lt_le_incl (S bound) (S m) pf)) with (vector_fold_right1_bounded_dep f init sing v (S bound) pf2); try eapply IHbound.
clear.
destruct bound; simpl.
-- erewrite index_pf_irrel; eauto.
@@ -1159,7 +1159,7 @@ Section Vector.
forall {m:nat} (v:Vector B m), P m v (vector_fold_right1_dep f init sing v).
Proof.
intros.
- rewrite <- (vfirstn_eq v (Le.le_refl m)) at 1.
+ rewrite <- (vfirstn_eq v (Nat.le_refl m)) at 1.
apply vector_fold_right1_bounded_dep_gen_ind; trivial.
Qed.
@@ -1171,8 +1171,9 @@ Section Vector.
Lemma vtake_skip_app_eq_pf n m (pf:(n<=m)%nat) : n + (m - n) = m.
Proof.
- rewrite NPeano.Nat.add_sub_assoc by trivial.
- now rewrite Minus.minus_plus.
+ rewrite Nat.add_sub_assoc by trivial.
+ rewrite Nat.add_comm.
+ apply Nat.add_sub.
Defined.
Lemma vtake_skip_app_lt_pf {m n i} (pf:(n<=m)%nat) (p2f:i < m) : i < n + (m - n).
diff --git a/coq/NeuralNetworks/derivlemmas.v b/rocq/utils/derivlemmas.v
similarity index 100%
rename from coq/NeuralNetworks/derivlemmas.v
rename to rocq/utils/derivlemmas.v
diff --git a/coq/utils/improper_integrals.v b/rocq/utils/improper_integrals.v
similarity index 100%
rename from coq/utils/improper_integrals.v
rename to rocq/utils/improper_integrals.v
diff --git a/coq/utils/quotient_space.v b/rocq/utils/quotient_space.v
similarity index 100%
rename from coq/utils/quotient_space.v
rename to rocq/utils/quotient_space.v