OCaml 工具链安装教程
2026年4月19日
现在,我们按照下面官方给出的教程,编写几个简单的程序进行ocaml工具链的验证:
- ocaml:https://ocaml.org/docs/your-first-program
- rocq:https://rocq-prover.org/doc/V9.2.0/refman/index.html
- why3:https://why3.org/doc/exec.html
- moonbit:https://docs.moonbitlang.com/zh-cn/latest/tutorial/tour.html
接着,在Visual Studio Code的Extensions中配置如下环境:
- OCaml Platform v2.0.1
- VsRocq v2.4.3
- MoonBit Language v0.7.2026041003
配置完成后,首先对ocaml本身进行验证:由于rocq和why3都是基于ocaml的,我们应当先保证ocaml开发环境的正确.逐步输入如下命令,切换到你所希望的目录,并创建一个project:
cd ~
mkdir ./ocaml_workspace
mkdir ./ocaml_workspace/ocaml
cd ./ocaml_workspace/ocaml
opam exec -- dune init proj hello
输入完成后,在对应路径下应当会生成一个名为hello的项目文件夹
该文件夹结构下对应路径的解释如下:
hello #项目名称
├── bin #主程序目录
│ ├── dune #s-exp格式的dune配置文件
│ └── main.ml #主程序代码文件
├── _build #构建输出目录
│ └── log #构建日志
├── dune-project #项目的dune配置文件
├── hello.opam #opam包定义文件
├── lib #引用库目录
│ └── dune #s-exp格式的dune配置文件
└── test #测试用例目录
    ├── dune #s-exp格式的dune配置文件
    └── test_hello.ml #测试用例代码文件
作为一门支持命令式风格的函数式程序设计语言,ocaml没有主函数入口.相对的,其代码的逻辑入口可以通过一个带副作用的顶层let绑定(例如`let () = ...`)来实现.
打开处于./ocaml_workspace/ocaml/hello/bin/main.ml的代码,得到如下的结果:

接下来我们继续输入以下命令,对程序进行构建:
cd ./hello
opam exec -- dune exec hello
出现如上的内容,就说明我们的第一个ocaml程序构建并运行成功了.
在ocaml中,每一个文件编译后都会产生一个module.接下来,我们来关注多文件程序项目与接口相关的内容,输入如下命令:
> ./lib/en.mli > ./lib/en.ml
> ./lib/es.mli > ./lib/es.ml
cd ./lib
ls
此时理想的文件情况应该如下:
点进./lib/en.mli与./lib/es.mli,输入如下的相同内容:
/hello/lib/en.mli & hello/lib/es.mli
val v : string
然后如下修改./lib/en.ml与./lib/es.ml中的内容:
/hello/lib/en.ml
let hello = "Hello";;
let v = hello ^ ", OCaml!";;
/hello/lib/es.ml
let v = "Hello FP!"
在./bin/main.ml中如下所示修改代码:
/hello/bin/main.ml
let mainFunc (num) : int =
Printf.printf "%s %d\n" Hello.Es.v num;
Printf.printf "%s\n" Hello.En.v;
num
;;
let () =
ignore (mainFunc 20260419)
;;
最后重新执行构建,逐步输入以下命令:
cd ..
opam exec -- dune exec hello
得到如上结果就说明我们在./lib处编写的两个module已经成功构建并且运行了.
接着我们来看到测试用例部分,来到./test/test_hello.ml文件:
/hello/test/test_hello.ml
let test_en=Hello.En.v;;
let test_es=Hello.Es.v;;
let () =
Printf.printf "%s\n" test_en;
Printf.printf "%s\n" test_es
;;
在测试用例代码编写完成后,如下进行测试:
dune test
出现如上提示则说明测试成功,使用cd命令返回ocaml_workspace目录.
接下来让我们看到rocq,首先,在先前Visual Studio Code配置扩展的基础上,如下步骤安装vsrocq-language-server并验证可用:
# Install the Rocq core package and the language server used by the VsRocq extension.
opam install coq-core
opam install vsrocq-language-server
# Confirm that the language-server binary is on PATH.
which vsrocqtop
# Write the VsRocq path setting. `sudo` is dropped on purpose: the `>` redirection
# is performed by the calling shell, so `sudo echo` never gave the write any extra
# privilege. NOTE: `>` overwrites an existing settings.json.
mkdir -p ../.vscode
echo "{\"vsrocq.path\": \"$(which vsrocqtop)\"}" > ../.vscode/settings.json
# Create the dune project used for the Rocq examples and enter it.
opam exec -- dune init proj hellorocq
cd ./hellorocq
> ./test_rocq.v
先来到./test_rocq.v文件:
/test_rocq.v
Fixpoint eqb (n m : nat) : bool :=
match n with
| O => match m with
| O => true
| S m' => false
end
| S n' => match m with
| O => false
| S m' => eqb n' m'
end
end.
Notation "x =? y" := (eqb x y)
(at level 70) : nat_scope.
Theorem zero_nbeq_plus_1 : forall n : nat,
0 =? (n + 1) = false.
Proof.
intros n.
destruct n eqn:Sn.
- reflexivity.
- reflexivity.
Qed.
接着,使用右上角的VsRocq进行逐步验证:

按"Step Forward(alt+↓)"可以逐步进行验证,按"Interpret To End(alt+end)"可以直接一步验证完整个.v文件,其他几个按键同理类推.下面是上面情况下验证一步的结果:

Rocq-prover是一种基于元编程与类型理论的交互式定理证明器(ITP),也称证明助手(PA).在逐步进行命题/定理/引理等相关证明时(特别是执行到tactic时),可以通过右侧工作区同步观察对应证明目标的变化情况.
例如,我们继续上面的证明:
| /test_rocq.v | Rocq 目标状态 |
|---|---|
| |
| |
| |
| |
| |
| |
| |
| |
上面为我们演示了一个使用tactics完成Theorem证明的案例.
接下来让我们来看到如何使用dune来构建一个Rocq项目,以及它是如何与OCaml协作的.如前所示,我们已经创建了一个新dune项目,现在来到./hellorocq目录下(如果你此前已经执行过cd ./hellorocq并且仍位于该目录中,请跳过下面的第一条cd命令):
cd ./hellorocq
mkdir ./theories
> ./theories/Order.v > ./theories/VerifiedMergeSort.v
> ./theories/dune > ./lib/Extract.v
> ./lib/hellorocq.ml > ./lib/hellorocq.mli
接着逐步打开目录下各个文件进行添加修改或替换,首先在./dune-project上添加以下一行并修改package中的depends以让dune支持rocq相关环境:
/hellorocq/dune-project
(using rocq 0.12)
...
(package ...
(depends ocaml dune rocq-core rocq-stdlib)
...)
...
接着,对目录下的其他dune文件进行修改替换,主要是./lib/dune与./theories/dune两个文件:
/hellorocq/lib/dune
(library
(name hellorocq)
(flags
(:standard -w -33))
(modules hellorocq Datatypes Orders PeanoNat Mergesort Order VerifiedMergeSort))
(rocq.extraction
(stdlib no)
(prelude Extract)
(extracted_modules Datatypes Orders PeanoNat Mergesort Order VerifiedMergeSort)
(theories Corelib Stdlib RocqmlVerified))
/hellorocq/theories/dune
; Rocq theory built from the .v files in ./theories.
; The theory name is the logical prefix under which its modules are loaded, so it
; must be the same name used by `From RocqmlVerified Require Import ...` in
; Extract.v / VerifiedMergeSort.v and by the (theories ...) field of lib/dune.
(rocq.theory
 (stdlib no)
 (name RocqmlVerified)
 ; NOTE(review): `bar` is not one of the modules listed below; confirm whether
 ; this per-module flag entry is intended or left over from an example.
 (modules_flags
  (bar (:standard \ -quiet)))
 (package hellorocq)
 (modules Order VerifiedMergeSort)
 (theories Corelib Stdlib))
接着,针对涉及到的.v, .ml与.mli文件进行代码编写,参考如下:
/hellorocq/bin/main.ml
(* Render an int list as "[a; b; c]"; the empty list renders as "[]". *)
let show_int_list xs =
  Printf.sprintf "[%s]" (String.concat "; " (List.map string_of_int xs))
;;
let default_input = [ 5; 3; 8; 1; 3; 0; 2 ];;
let input =
match Array.to_list Sys.argv with
| _ :: [] -> default_input
| _ :: args -> List.map int_of_string args
| [] -> default_input
;;
let () =
let sorted = Hellorocq.sort_nat input in
Printf.printf "%s\n" (show_int_list sorted)
;;
/hellorocq/lib/hellorocq.mli
val sort_nat : int list -> int list
/hellorocq/lib/hellorocq.ml
(* Guard used before sorting: raises [Invalid_argument] when [n] is negative and
   returns unit otherwise. The message names the public entry point of this
   library, [Hellorocq.sort_nat], so the failure can be traced to the call site. *)
let require_nat n =
  if n < 0 then invalid_arg "Hellorocq.sort_nat: expected non-negative integers"
;;
let sort_nat xs =
List.iter require_nat xs;
VerifiedMergeSort.sort_nat xs
;;
/hellorocq/test/test_hellorocq.ml
(* Format an int list for test output, e.g. [1; 2; 3] becomes "[1; 2; 3]". *)
let show_int_list xs =
  let rendered = List.map string_of_int xs in
  Printf.sprintf "[%s]" (String.concat "; " rendered)
;;
let check_sort name input expected =
let actual = Hellorocq.sort_nat input in
if actual <> expected then
failwith
(Printf.sprintf "%s: expected %s, got %s" name (show_int_list expected)
(show_int_list actual))
else
(Printf.printf "[%s] passed \n" name)
;;
let () =
check_sort "empty" [] [];
check_sort "singleton" [ 7 ] [ 7 ];
check_sort "duplicates and boundaries" [ 5; 3; 8; 1; 3; 0; 2 ]
[ 0; 1; 2; 3; 3; 5; 8 ]
;;
下面主要是一些.v格式的rocq脚本:
/hellorocq/lib/Extract.v
From Stdlib Require Import extraction.Extraction.
From Stdlib Require Import extraction.ExtrOcamlBasic.
From Stdlib Require Import extraction.ExtrOcamlNatInt.
From RocqmlVerified Require Import VerifiedMergeSort.
Set Extraction Output Directory "./".
Extraction Language OCaml.
Separate Extraction sort_nat.
/hellorocq/theories/Order.v
From Corelib Require Import Init.Prelude.
From Stdlib Require Import Arith.PeanoNat micromega.Lia Structures.Orders.
Module NatOrder <: TotalLeBool.
Definition t := nat.
Definition leb := Nat.leb.
Infix "<=?" := leb (at level 70, no associativity).
Theorem leb_total : forall x y, (x <=? y) = true \/ (y <=? x) = true.
Proof.
intros x y.
unfold leb.
destruct (Nat.leb x y) eqn:Hxy; auto.
right.
apply Nat.leb_gt in Hxy.
apply Nat.leb_le.
lia.
Qed.
End NatOrder.
/hellorocq/theories/VerifiedMergeSort.v
From Corelib Require Import Init.Prelude.
From Stdlib Require Import List Sorting.Mergesort Sorting.Permutation Sorting.Sorted.
From RocqmlVerified Require Import Order.
Import ListNotations.
Module NatMergeSort := Mergesort.Sort NatOrder.
Definition sort_nat : list nat -> list nat := NatMergeSort.sort.
Theorem sort_nat_permutation : forall xs, Permutation xs (sort_nat xs).
Proof.
intro xs.
unfold sort_nat.
apply NatMergeSort.Permuted_sort.
Qed.
Theorem sort_nat_sorted :
forall xs, Sorted (fun x y => Nat.leb x y = true) (sort_nat xs).
Proof.
intro xs.
unfold sort_nat.
apply NatMergeSort.Sorted_sort.
Qed.
Example sort_nat_example :
sort_nat [5; 3; 8; 1; 3; 0; 2] = [0; 1; 2; 3; 3; 5; 8].
Proof.
vm_compute.
reflexivity.
Qed.
到此便可以对相关程序进行测试验证了,如下输入对应命令:
dune build
opam exec -- dune exec hellorocq
dune test
至此,对rocq的验证也完成了,让我们回到ocaml_workspace目录,继续来到why3.
sudo yum install gtk3-devel gtksourceview3-devel
opam install alt-ergo z3 cvc5
opam install why3-coq coq-flocq coq-interval
why3 config detect
why3 config list-provers
安装完成后,得到的结果应该如下所示:

配置完环境后,如下输入命令,编写几个简单的例子来验证why3:
mkdir -p ./why3 && cd ./why3
> ./why2ocaml.mlw > ./why2c.mlw > ./why2java.mlw
如下配置对应的程序代码:
/why2c.mlw
use int.Int
use map.Map as Map
use mach.c.C
use mach.int.Int32
use mach.int.Int64
function ([]) (a: ptr 'a) (i: int): 'a = Map.get a.data.Array.elts (a.offset + i)
type r = { mutable x : int32; mutable y : int32 }
let locate_max (a: ptr int64) (n: int32): int32
requires { 0 < n }
requires { valid a n }
ensures { 0 <= result < n }
ensures { forall i. 0 <= i < n -> a[i] <= a[result] }
= let ref idx = 0 in
for j = 1 to n - 1 do
invariant { 0 <= idx < n }
invariant { forall i. 0 <= i < j -> a[i] <= a[idx] }
if get_ofs a idx < get_ofs a j then idx <- j
done;
idx
let swap (a : r) : unit =
let tmp = a.y in a.y <- a.x; a.x <- tmp
/why2ocaml.mlw
module MaxAndSum
use int.Int
use ref.Ref
use array.Array
let max_sum (a: array int) (n: int) : (int, int)
requires { n = length a }
requires { forall i. 0 <= i < n -> a[i] >= 0 }
returns { sum, max -> sum <= n * max }
= let sum = ref 0 in
let max = ref 0 in
for i = 0 to n - 1 do
invariant { !sum <= i * !max }
if !max < a[i] then max := a[i];
sum := !sum + a[i]
done;
!sum, !max
let test () =
let n = 10 in
let a = (make n 0) in
a[0] <- 9; a[1] <- 5;
a[2] <- 0; a[3] <- 2;
a[4] <- 7; a[5] <- 3;
a[6] <- 2; a[7] <- 1;
a[8] <- 10; a[9] <- 6;
max_sum a n
end
/why2java.mlw
module Employee
use mach.java.lang.Integer
use mach.java.lang.String
use mach.java.util.Map
type service_id = HUMAN_RES | TECHNICAL | BUSINESS
type t = {
name: string;
room: integer;
phone : string;
service : service_id;
}
let create_employee [@java:constructor]
(name : string) (room : integer)
(phone : string) (service : service_id) = {
name = name;
room = room;
phone = phone;
service = service;
}
end
module EmployeeAlreadyExistsException [@java:exception:RuntimeException]
use mach.java.lang.String
type t [@extraction:preserve_single_field] = { msg : string }
exception E t
let constructor[@java:constructor](name : string) : t = {
msg = (String.format_1 "Employee '%s' already exists" name)
}
let getMessage(self : t) : string = self.msg
end
module Directory
use int.Int
use mach.java.lang.String
use mach.java.lang.Integer
use mach.java.util.Map
use Employee
use EmployeeAlreadyExistsException
type t [@extraction:preserve_single_field]= {
employees [@java:visibility:private] : Map.map string Employee.t
}
let create_directory [@java:constructor] () : t = {
employees = Map.empty()
}
let add_employee (self : t) (name : string)
(phone : string) (room : integer)
(service : service_id) : unit
ensures { Map.containsKey self.employees name }
ensures { let m = Map.get self.employees name in
m.name = name && m.phone = phone &&
m.room = room && m.service = service }
raises { EmployeeAlreadyExistsException.E _ ->
Map.containsKey (old self.employees) name }
=
if Map.containsKey self.employees name then
raise (EmployeeAlreadyExistsException.E (constructor name));
Map.put self.employees name (
Employee.create_employee name room phone service)
end
why3的程序结构比较简单,接下来我们通过以下的命令尝试测试运行并extract代码文件中的WhyML为不同语言的代码:
why3 execute why2ocaml.mlw --use=MaxAndSum 'test ()'
why3 extract -D ocaml64 why2ocaml.mlw -o why2ocaml.ml
why3 extract -D c why2c.mlw -o why2c.h
why3 extract -L . -o . -D java --recursive --modular why2java.mlw
ls
提取完代码后,应该能够看到转换为其他语言代码的WhyML程序:

接下来让我们尝试一下如何使用Why3调用ITP Prover(如Rocq)进行证明:
> ./coq4why.why
按照下面的示例代码进行测试:
/coq4why.why
theory HelloProof
use int.Int
goal G1: true
goal G2: (true -> false) /\ (true \/ false)
goal G3: forall x:int. x * x >= 0
end
接着如下进行逐步调试,下面的步骤将引导你创建session->调用prover证明->replay所得到记录的session.
why3 session create -o ./coq4whysession ./coq4why.why
why3 prove -P z3 -o ./coq4whysession ./coq4why.why
why3 replay ./coq4whysession
但是由于其目前在无GUI环境下的相关支持似乎不是很好,这里让我们暂且回避这个问题.

我们继续来看到Why3 API是如何作为OCaml的一部分发挥作用的:
> ./why4ocaml.ml
创建文件后,填充以下代码:
/why4ocaml.ml
open Why3
open Format
let fmla_true : Term.term = Term.t_true;;
let fmla_false : Term.term = Term.t_false;;
let fmla1 : Term.term = Term.t_or fmla_true fmla_false;;
let prop_var_A : Term.lsymbol =
Term.create_psymbol (Ident.id_fresh "A") [];;
let prop_var_B : Term.lsymbol =
Term.create_psymbol (Ident.id_fresh "B") [];;
let atom_A : Term.term =
Term.ps_app prop_var_A [];;
let atom_B : Term.term =
Term.ps_app prop_var_B [];;
let fmla2 : Term.term =
Term.t_implies (Term.t_and atom_A atom_B) atom_A;;
let task1 : Task.task = None
let goal_id1 : Decl.prsymbol =
Decl.create_prsymbol (Ident.id_fresh "goal1");;
let task1 : Task.task =
Task.add_prop_decl task1 Decl.Pgoal goal_id1 fmla1;;
let task2 : Task.task = None
let task2 : Task.task =
Task.add_param_decl task2 prop_var_A;;
let task2 : Task.task =
Task.add_param_decl task2 prop_var_B;;
let goal_id2 =
Decl.create_prsymbol (Ident.id_fresh "goal2");;
let task2 = Task.add_prop_decl task2 Decl.Pgoal goal_id2 fmla2;;
let config : Whyconf.config =
Whyconf.init_config None;;
let main : Whyconf.main =
Whyconf.get_main config;;
let provers : Whyconf.config_prover Whyconf.Mprover.t =
Whyconf.get_provers config;;
let limits =
Call_provers.{empty_limits with
limit_time = Whyconf.timelimit main;
limit_mem = Whyconf.memlimit main }
let alt_ergo : Whyconf.config_prover =
let fp = Whyconf.parse_filter_prover "Alt-Ergo" in
(* all provers that have the name "Alt-Ergo" *)
let provers = Whyconf.filter_provers config fp in
if Whyconf.Mprover.is_empty provers then begin
eprintf "Prover Alt-Ergo not installed or not configured@.";
exit 1
end else begin
printf "Versions of Alt-Ergo found:";
Whyconf.(Mprover.iter (fun k _ -> printf " %s" k.prover_version) provers);
printf "@.";
(* returning an arbitrary one *)
snd (Whyconf.Mprover.max_binding provers)
end
let env : Env.env =
Env.create_env (Whyconf.loadpath main);;
let alt_ergo_driver : Driver.driver =
try
Driver.load_driver_for_prover main env alt_ergo
with e ->
eprintf "Failed to load driver for alt-ergo: %a@."
Exn_printer.exn_printer e;
exit 1
;;
let call_task (task) : Call_provers.prover_result =
Call_provers.wait_on_call
(Driver.prove_task
~limits
~config:main
~command:alt_ergo.Whyconf.command
alt_ergo_driver
task)
;;
let () =
printf "@[formula 1 is:@ %a@]@." Pretty.print_term fmla1;
printf "@[formula 2 is:@ %a@]@." Pretty.print_term fmla2;
printf "@[task 1 is:@\n%a@]@." Pretty.print_task task1;
printf "@[task 2 created:@\n%a@]@." Pretty.print_task task2;
printf "@[On task 1, Alt-Ergo answers %a@]@."
(Call_provers.print_prover_result ?json:(Some false)) (call_task task1);
printf "@[On task 2, Alt-Ergo answers %a@]@."
(Call_provers.print_prover_result ?json:(Some false)) (call_task task2)
;;
我们通过如下的方式对其进行构建并且运行(如果你乐意的话,你也可以试试dune):
ocamlfind ocamlc -package why3 -linkpkg why4ocaml.ml -o why4ocaml
./why4ocaml
得到的结果应该如下图所示:

至此,我们完成了对Why3环境的验证,该是时候回到ocaml_workspace并且尝试一下使用国产的moonbit语言了.

相对而言,由于基于Why3的形式化验证功能尚在起步阶段,我们尚且不对其进行讨论,因而,这里对moonbit功能的验证比较简单:
mkdir -p ./moonbit && cd ./moonbit
moon new neural_network
cd ./neural_network
接着配置示例代码:
/neural_network/cmd/main/main.mbt
struct Data {
input : Array[Double]
expected : Int
}
let train_dataset : Array[Data] = [
{ input: [5.1, 3.5, 1.4, 0.2], expected: 0 }, { input: [4.9, 3.0, 1.4, 0.2], expected: 0 },
{ input: [4.7, 3.2, 1.3, 0.2], expected: 0 }, { input: [4.6, 3.1, 1.5, 0.2], expected: 0 },
{ input: [5.0, 3.6, 1.4, 0.2], expected: 0 }, { input: [5.4, 3.9, 1.7, 0.4], expected: 0 },
{ input: [4.6, 3.4, 1.4, 0.3], expected: 0 }, { input: [5.4, 3.7, 1.5, 0.2], expected: 0 },
{ input: [4.8, 3.4, 1.6, 0.2], expected: 0 }, { input: [4.8, 3.0, 1.4, 0.1], expected: 0 },
{ input: [4.3, 3.0, 1.1, 0.1], expected: 0 }, { input: [5.8, 4.0, 1.2, 0.2], expected: 0 },
{ input: [5.7, 4.4, 1.5, 0.4], expected: 0 }, { input: [5.4, 3.9, 1.3, 0.4], expected: 0 },
{ input: [5.1, 3.5, 1.4, 0.3], expected: 0 }, { input: [5.7, 3.8, 1.7, 0.3], expected: 0 },
{ input: [5.1, 3.8, 1.5, 0.3], expected: 0 }, { input: [7.0, 3.2, 4.7, 1.4], expected: 1 },
{ input: [6.1, 2.8, 4.7, 1.2], expected: 1 }, { input: [6.4, 2.9, 4.3, 1.3], expected: 1 },
{ input: [6.6, 3.0, 4.4, 1.4], expected: 1 }, { input: [6.8, 2.8, 4.8, 1.4], expected: 1 },
{ input: [6.7, 3.0, 5.0, 1.7], expected: 1 }, { input: [6.4, 2.7, 5.3, 1.9], expected: 2 },
{ input: [6.8, 3.0, 5.5, 2.1], expected: 2 }, { input: [5.7, 2.5, 5.0, 2.0], expected: 2 },
{ input: [5.8, 2.8, 5.1, 2.4], expected: 2 }, { input: [6.0, 2.9, 4.5, 1.5], expected: 1 },
{ input: [5.7, 2.6, 3.5, 1.0], expected: 1 }, { input: [5.5, 2.4, 3.8, 1.1], expected: 1 },
{ input: [5.5, 2.4, 3.7, 1.0], expected: 1 }, { input: [5.8, 2.7, 3.9, 1.2], expected: 1 },
{ input: [6.0, 2.7, 5.1, 1.6], expected: 1 }, { input: [5.4, 3.0, 4.5, 1.5], expected: 1 },
{ input: [6.0, 3.4, 4.5, 1.6], expected: 1 }, { input: [6.7, 3.1, 4.7, 1.5], expected: 1 },
{ input: [6.3, 2.3, 4.4, 1.3], expected: 1 }, { input: [5.6, 3.0, 4.1, 1.3], expected: 1 },
{ input: [5.5, 2.5, 4.0, 1.3], expected: 1 }, { input: [5.5, 2.6, 4.4, 1.2], expected: 1 },
{ input: [6.1, 3.0, 4.6, 1.4], expected: 1 }, { input: [5.8, 2.6, 4.0, 1.2], expected: 1 },
{ input: [5.0, 2.3, 3.3, 1.0], expected: 1 }, { input: [5.6, 2.7, 4.2, 1.3], expected: 1 },
{ input: [5.7, 3.0, 4.2, 1.2], expected: 1 }, { input: [5.0, 3.4, 1.5, 0.2], expected: 0 },
{ input: [4.4, 2.9, 1.4, 0.2], expected: 0 }, { input: [4.9, 3.1, 1.5, 0.1], expected: 0 },
{ input: [5.7, 2.9, 4.2, 1.3], expected: 1 }, { input: [6.2, 2.9, 4.3, 1.3], expected: 1 },
{ input: [5.1, 2.5, 3.0, 1.1], expected: 1 }, { input: [5.7, 2.8, 4.1, 1.3], expected: 1 },
{ input: [6.3, 3.3, 6.0, 2.5], expected: 2 }, { input: [5.8, 2.7, 5.1, 1.9], expected: 2 },
{ input: [7.1, 3.0, 5.9, 2.1], expected: 2 }, { input: [6.3, 2.9, 5.6, 1.8], expected: 2 },
{ input: [6.5, 3.0, 5.8, 2.2], expected: 2 }, { input: [7.6, 3.0, 6.6, 2.1], expected: 2 },
{ input: [4.9, 2.5, 4.5, 1.7], expected: 2 }, { input: [7.3, 2.9, 6.3, 1.8], expected: 2 },
{ input: [6.7, 2.5, 5.8, 1.8], expected: 2 }, { input: [7.2, 3.6, 6.1, 2.5], expected: 2 },
{ input: [6.5, 3.2, 5.1, 2.0], expected: 2 }, { input: [6.4, 3.2, 5.3, 2.3], expected: 2 },
{ input: [6.5, 3.0, 5.5, 1.8], expected: 2 }, { input: [7.7, 3.8, 6.7, 2.2], expected: 2 },
{ input: [7.7, 2.6, 6.9, 2.3], expected: 2 }, { input: [6.0, 2.2, 5.0, 1.5], expected: 2 },
]
let test_dataset : Array[Data] = [
{ input: [5.4, 3.4, 1.7, 0.2], expected: 0 }, { input: [5.1, 3.7, 1.5, 0.4], expected: 0 },
{ input: [4.6, 3.6, 1.0, 0.2], expected: 0 }, { input: [5.1, 3.3, 1.7, 0.5], expected: 0 },
{ input: [4.8, 3.4, 1.9, 0.2], expected: 0 }, { input: [5.0, 3.0, 1.6, 0.2], expected: 0 },
{ input: [5.0, 3.4, 1.6, 0.4], expected: 0 }, { input: [5.2, 3.5, 1.5, 0.2], expected: 0 },
{ input: [5.2, 3.4, 1.4, 0.2], expected: 0 }, { input: [4.7, 3.2, 1.6, 0.2], expected: 0 },
{ input: [4.8, 3.1, 1.6, 0.2], expected: 0 }, { input: [5.4, 3.4, 1.5, 0.4], expected: 0 },
{ input: [5.2, 4.1, 1.5, 0.1], expected: 0 }, { input: [5.5, 4.2, 1.4, 0.2], expected: 0 },
{ input: [4.9, 3.1, 1.5, 0.1], expected: 0 }, { input: [5.0, 3.2, 1.2, 0.2], expected: 0 },
{ input: [5.5, 3.5, 1.3, 0.2], expected: 0 }, { input: [4.9, 3.1, 1.5, 0.1], expected: 0 },
{ input: [4.4, 3.0, 1.3, 0.2], expected: 0 }, { input: [5.1, 3.4, 1.5, 0.2], expected: 0 },
{ input: [5.0, 3.5, 1.3, 0.3], expected: 0 }, { input: [4.5, 2.3, 1.3, 0.3], expected: 0 },
{ input: [4.4, 3.2, 1.3, 0.2], expected: 0 }, { input: [5.0, 3.5, 1.6, 0.6], expected: 0 },
{ input: [5.1, 3.8, 1.9, 0.4], expected: 0 }, { input: [4.8, 3.0, 1.4, 0.3], expected: 0 },
{ input: [5.1, 3.8, 1.6, 0.2], expected: 0 }, { input: [4.6, 3.2, 1.4, 0.2], expected: 0 },
{ input: [5.3, 3.7, 1.5, 0.2], expected: 0 }, { input: [5.0, 3.3, 1.4, 0.2], expected: 0 },
{ input: [6.4, 3.2, 4.5, 1.5], expected: 1 }, { input: [6.9, 3.1, 4.9, 1.5], expected: 1 },
{ input: [5.5, 2.3, 4.0, 1.3], expected: 1 }, { input: [6.5, 2.8, 4.6, 1.5], expected: 1 },
{ input: [5.7, 2.8, 4.5, 1.3], expected: 1 }, { input: [6.3, 3.3, 4.7, 1.6], expected: 1 },
{ input: [4.9, 2.4, 3.3, 1.0], expected: 1 }, { input: [6.6, 2.9, 4.6, 1.3], expected: 1 },
{ input: [5.2, 2.7, 3.9, 1.4], expected: 1 }, { input: [5.0, 2.0, 3.5, 1.0], expected: 1 },
{ input: [5.9, 3.0, 4.2, 1.5], expected: 1 }, { input: [6.0, 2.2, 4.0, 1.0], expected: 1 },
{ input: [6.1, 2.9, 4.7, 1.4], expected: 1 }, { input: [5.6, 2.9, 3.6, 1.3], expected: 1 },
{ input: [6.7, 3.1, 4.4, 1.4], expected: 1 }, { input: [5.6, 3.0, 4.5, 1.5], expected: 1 },
{ input: [5.8, 2.7, 4.1, 1.0], expected: 1 }, { input: [6.2, 2.2, 4.5, 1.5], expected: 1 },
{ input: [5.6, 2.5, 3.9, 1.1], expected: 1 }, { input: [5.9, 3.2, 4.8, 1.8], expected: 1 },
{ input: [6.1, 2.8, 4.0, 1.3], expected: 1 }, { input: [6.3, 2.5, 4.9, 1.5], expected: 1 },
{ input: [6.9, 3.2, 5.7, 2.3], expected: 2 }, { input: [5.6, 2.8, 4.9, 2.0], expected: 2 },
{ input: [7.7, 2.8, 6.7, 2.0], expected: 2 }, { input: [6.3, 2.7, 4.9, 1.8], expected: 2 },
{ input: [6.7, 3.3, 5.7, 2.1], expected: 2 }, { input: [7.2, 3.2, 6.0, 1.8], expected: 2 },
{ input: [6.2, 2.8, 4.8, 1.8], expected: 2 }, { input: [6.1, 3.0, 4.9, 1.8], expected: 2 },
{ input: [6.4, 2.8, 5.6, 2.1], expected: 2 }, { input: [7.2, 3.0, 5.8, 1.6], expected: 2 },
{ input: [7.4, 2.8, 6.1, 1.9], expected: 2 }, { input: [7.9, 3.8, 6.4, 2.0], expected: 2 },
{ input: [6.4, 2.8, 5.6, 2.2], expected: 2 }, { input: [6.3, 2.8, 5.1, 1.5], expected: 2 },
{ input: [6.1, 2.6, 5.6, 1.4], expected: 2 }, { input: [7.7, 3.0, 6.1, 2.3], expected: 2 },
{ input: [6.3, 3.4, 5.6, 2.4], expected: 2 }, { input: [6.4, 3.1, 5.5, 1.8], expected: 2 },
{ input: [6.0, 3.0, 4.8, 1.8], expected: 2 }, { input: [6.9, 3.1, 5.4, 2.1], expected: 2 },
{ input: [6.7, 3.1, 5.6, 2.4], expected: 2 }, { input: [6.9, 3.1, 5.1, 2.3], expected: 2 },
{ input: [5.8, 2.7, 5.1, 1.9], expected: 2 }, { input: [6.8, 3.2, 5.9, 2.3], expected: 2 },
{ input: [6.7, 3.3, 5.7, 2.5], expected: 2 }, { input: [6.7, 3.0, 5.2, 2.3], expected: 2 },
{ input: [6.3, 2.5, 5.0, 1.9], expected: 2 }, { input: [6.5, 3.0, 5.2, 2.0], expected: 2 },
{ input: [6.2, 3.4, 5.4, 2.3], expected: 2 }, { input: [5.9, 3.0, 5.1, 1.8], expected: 2 },
]
fn[T : Base] reLU(t : T) -> T {
if t.value() < 0.0 {
T::constant(0.0)
} else {
t
}
}
fn[T : Base] softmax(inputs : Array[T]) -> Array[T] {
let n = inputs.length()
let sum = inputs.fold(init=T::constant(0.0), (acc, input) => acc + input.exp())
let outputs : Array[T] = Array::makei(n, i => inputs[i].exp() / sum)
outputs
}
fn[T : Base] input2hidden(
inputs : Array[Double],
param : Array[Array[T]],
) -> Array[T] {
let outputs : Array[T] = Array::makei(param.length(), o => {
reLU(
inputs.foldi(init=T::constant(0.0), (index, acc, input) => {
acc + T::constant(input) * param[o][index]
}) +
param[o][inputs.length()],
)
})
outputs
}
fn[T : Base] hidden2output(
inputs : Array[T],
param : Array[Array[T]],
) -> Array[T] {
let outputs : Array[T] = Array::makei(param.length(), o => {
inputs.foldi(init=T::constant(0.0), (index, acc, input) => {
acc + input * param[o][index]
}) +
param[o][inputs.length()]
})
outputs |> softmax
}
pub struct Backward {
value : Double
propagate : () -> Unit
backward : (Double) -> Unit
}
impl Base for Backward with constant(d : Double) -> Backward {
{ value: d, propagate: fn() { }, backward: _ => () }
}
fn Backward::backward(b : Backward) -> Unit {
(b.propagate)()
(b.backward)(1.0)
}
impl Base for Backward with value(backward : Backward) -> Double {
backward.value
}
impl Add for Backward with add(b1 : Backward, b2 : Backward) -> Backward {
let counter = Ref::{ val: 0 }
let cumul = Ref::{ val: 0.0 }
{
value: b1.value + b2.value,
propagate: fn() {
counter.val = counter.val + 1
if counter.val == 1 {
(b1.propagate)()
(b2.propagate)()
}
},
backward: fn(diff) {
counter.val = counter.val - 1
cumul.val = cumul.val + diff
if counter.val == 0 {
(b1.backward)(cumul.val)
(b2.backward)(cumul.val)
}
},
}
}
impl Neg for Backward with neg(b : Backward) -> Backward {
let counter = Ref::{ val: 0 }
let cumul = Ref::{ val: 0.0 }
{
value: -b.value,
propagate: fn() {
counter.val = counter.val + 1
if counter.val == 1 {
(b.propagate)()
}
},
backward: fn(diff) {
counter.val = counter.val - 1
cumul.val = cumul.val + diff
if counter.val == 0 {
(b.backward)(-cumul.val)
}
},
}
}
impl Mul for Backward with mul(b1 : Backward, b2 : Backward) -> Backward {
let counter = Ref::{ val: 0 }
let cumul = Ref::{ val: 0.0 }
{
value: b1.value * b2.value,
propagate: fn() {
counter.val = counter.val + 1
if counter.val == 1 {
(b1.propagate)()
(b2.propagate)()
}
},
backward: fn(diff) {
counter.val = counter.val - 1
cumul.val = cumul.val + diff
if counter.val == 0 {
(b1.backward)(cumul.val * b2.value)
(b2.backward)(cumul.val * b1.value)
}
},
}
}
impl Div for Backward with div(b1 : Backward, b2 : Backward) -> Backward {
let counter = Ref::{ val: 0 }
let cumul = Ref::{ val: 0.0 }
{
value: b1.value / b2.value,
propagate: fn() {
counter.val = counter.val + 1
if counter.val == 1 {
(b1.propagate)()
(b2.propagate)()
}
},
backward: fn(diff) {
counter.val = counter.val - 1
cumul.val = cumul.val + diff
if counter.val == 0 {
(b1.backward)(cumul.val / b2.value)
(b2.backward)(-cumul.val * b1.value / b2.value / b2.value)
}
},
}
}
trait Base: Add + Neg + Mul + Div {
constant(Double) -> Self
value(Self) -> Double
exp(Self) -> Self
}
impl Base for Double with constant(b : Double) -> Double {
b
}
impl Base for Double with value(b : Double) -> Double {
b
}
impl Base for Double with exp(b : Double) -> Double {
@math.exp(b)
}
impl Base for Backward with exp(b : Backward) -> Backward {
let b_exp = Base::exp(b.value)
let counter = Ref::{ val: 0 }
let cumul = Ref::{ val: 0.0 }
{
value: b_exp,
propagate: fn() {
counter.val = counter.val + 1
if counter.val == 1 {
(b.propagate)()
}
},
backward: fn(diff) {
counter.val = counter.val - 1
cumul.val = cumul.val + diff
if counter.val == 0 {
(b.backward)(cumul.val * b_exp)
}
},
}
}
trait Log {
log(Self) -> Self
}
// Natural logarithm of `b`.
// The forward value must be the natural log (not log2): the `Log for Backward`
// implementation propagates `diff / b.value`, i.e. the derivative 1/x of ln(x),
// so using base 2 here would make the value and its gradient disagree.
impl Log for Double with log(b : Double) -> Double {
  @math.ln(b)
}
impl Log for Backward with log(b : Backward) -> Backward {
let counter = Ref::{ val: 0 }
let cumul = Ref::{ val: 0.0 }
{
value: Log::log(b.value),
propagate: fn() {
counter.val = counter.val + 1
if counter.val == 1 {
(b.propagate)()
}
},
backward: fn(diff) {
counter.val = counter.val - 1
cumul.val = cumul.val + diff
if counter.val == 0 {
(b.backward)(cumul.val / b.value)
}
},
}
}
fn[T : Base + Log] cross_entropy(inputs : Array[T], expected : Int) -> T {
-inputs[expected].log()
}
fn[T : Base] compute(
inputs : Array[Double],
param_hidden : Array[Array[T]],
param_output : Array[Array[T]],
) -> Array[T] {
inputs |> input2hidden(param_hidden) |> hidden2output(param_output)
}
fn[N : Base] verify(result : Array[N], expected : Int) -> Bool {
match expected {
0 =>
return result[0].value() > result[1].value() &&
result[0].value() > result[2].value()
1 =>
return result[1].value() > result[0].value() &&
result[1].value() > result[2].value()
2 =>
return result[2].value() > result[1].value() &&
result[2].value() > result[0].value()
_ => abort("")
}
}
fn diff(
inputs : Array[Double],
expected : Int,
param_hidden : Array[Array[Backward]],
param_output : Array[Array[Backward]],
) -> Unit {
let result = compute(inputs, param_hidden, param_output)
|> cross_entropy(expected)
result.backward()
}
fn update(
params : Array[Array[Double]],
diff : Array[Array[Double]],
step : Double,
) -> Unit {
for i = 0; i < params.length(); i = i + 1 {
for j = 0; j < params[i].length(); j = j + 1 {
params[i][j] -= step * diff[i][j]
}
}
}
fn train(
param_hidden : Array[Array[Double]],
param_output : Array[Array[Double]],
step : Double,
) -> Unit {
let diff_hidden : Array[Array[Double]] = Array::makei(param_hidden.length(), i => {
Array::make(param_hidden[i].length(), 0.0)
})
let backward_hidden : Array[Array[Backward]] = Array::makei(
param_hidden.length(),
i => {
Array::makei(param_hidden[i].length(), j => {
value: param_hidden[i][j],
propagate: fn() { },
backward: d => diff_hidden[i][j] += d,
})
},
)
let diff_output : Array[Array[Double]] = Array::makei(param_output.length(), i => {
Array::make(param_output[i].length(), 0.0)
})
let backward_output : Array[Array[Backward]] = Array::makei(
param_output.length(),
i => {
Array::makei(param_output[i].length(), j => {
value: param_output[i][j],
propagate: fn() { },
backward: d => diff_output[i][j] += d,
})
},
)
for i = 0; i < train_dataset.length(); i = i + 1 {
diff(
train_dataset[i].input,
train_dataset[i].expected,
backward_hidden,
backward_output,
)
}
update(param_hidden, diff_hidden, step)
update(param_output, diff_output, step)
}
fn main {
let hidden_layer : Array[Array[Double]] = Array::makei(4, _ => {
Array::make(5, 0.2)
})
let output_layer : Array[Array[Double]] = Array::makei(3, _ => {
Array::make(5, 0.2)
})
let learning_rate = 0.00028
let decay_rate = -0.001
for i = 0; i < 400; i = i + 1 {
train(
hidden_layer,
output_layer,
learning_rate * Base::exp(decay_rate * i.to_double()),
)
}
for i = 0, correct = 0, wrong = 0; i < test_dataset.length(); {
let result = compute(test_dataset[i].input, hidden_layer, output_layer)
println(
"epoch: \{i} Acc: \{correct.to_double()/(wrong.to_double()+correct.to_double())}",
)
if verify(result, test_dataset[i].expected) {
continue i + 1, correct + 1, wrong
} else {
continue i + 1, correct, wrong + 1
}
} nobreak {
println("Correct: \{correct} Wrong: \{wrong}")
println(
"Final Acc: \{correct.to_double()/(wrong.to_double()+correct.to_double())}",
)
}
}
代码配置完成后,参照以下命令进行运行:
moon run cmd/main
运行后,应当见到以下的结果:

至此,我们对相关功能的验证都告一段落,祝你好运.