
OCaml Toolchain Installation Tutorial

April 19, 2026

Now, following the official tutorials, we will write a few simple programs to verify the OCaml toolchain.

Next, install the following extensions from the Extensions view in Visual Studio Code:

  • OCaml Platform v2.0.1
  • VsRocq v2.4.3
  • MoonBit Language v0.7.2026041003

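With the extensions installed, start by verifying OCaml itself. Rocq and Why3 are both built on OCaml, so the OCaml development environment needs to work first. Enter the following commands one at a time to switch to a directory of your choice and create a project: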

bash
cd ~
mkdir ./ocaml_workspace
mkdir ./ocaml_workspace/ocaml
cd ./ocaml_workspace/ocaml
opam exec -- dune init proj hello

When the commands finish, a project folder named hello should exist in that directory.

The folder's layout is explained below:

text
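hello                        # project name
├── bin                      # executable (main program) directory
│   ├── dune                 # dune config file in s-expression format
│   └── main.ml              # main program source
├── _build                   # build output directory
│   └── log                  # build log
├── dune-project             # project-level dune configuration
├── hello.opam               # opam package definition
├── lib                      # library directory
│   └── dune                 # dune config file in s-expression format
└── test                     # test directory
    ├── dune                 # dune config file in s-expression format
    └── test_hello.ml        # test source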

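OCaml is a functional language that also supports imperative programming, and it has no `main` entry function. A program's logical entry point is usually a top-level `let () = ...` binding whose right-hand side performs side effects. This is not an anonymous function. It is an ordinary `let` binding that matches the expression's `unit` result, and top-level bindings run in order when the program starts.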
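Here is a minimal sketch of this idea. The names `greeting`, `counter` and `after_entry` are only for illustration. The bindings run from top to bottom when the program starts. Matching the result against `()` makes the compiler check that the right-hand side really has type `unit`.

```ocaml
(* Top-level bindings are evaluated in source order at program start. *)
let greeting = "Hello, World!"

(* A mutable counter, used only to make the evaluation order observable. *)
let counter = ref 0

(* The de facto entry point: the pattern () forces the body to be unit. *)
let () =
  incr counter;
  print_endline greeting

(* Evaluated after the binding above, so the counter is already 1 here. *)
let after_entry = !counter
```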

Open ./ocaml_workspace/ocaml/hello/bin/main.ml. Its contents are as follows:

Next, enter the following commands to build and run the program:

bash
cd ./hello
opam exec -- dune exec hello

If you see output like the above, your first OCaml program has been built and run successfully.

In OCaml, each source file is compiled into a module. Next, let's look at multi-file projects and interfaces. Enter the following commands:

bash
> ./lib/en.mli > ./lib/en.ml
> ./lib/es.mli > ./lib/es.ml
cd ./lib
ls

The files should now look like this:

Open ./lib/en.mli and ./lib/es.mli and enter the same content in both:

/hello/lib/en.mli & /hello/lib/es.mli

ocaml
val v : string

Then change the contents of ./lib/en.ml and ./lib/es.ml as follows:

/hello/lib/en.ml

ocaml
let hello = "Hello";;
let v = hello ^ ", OCaml!";;

/hello/lib/es.ml

ocaml
let v = "Hello FP!"

Modify ./bin/main.ml as shown below:

/hello/bin/main.ml

ocaml
let mainFunc (num) : int =
  Printf.printf "%s %d\n" Hello.Es.v num;
  Printf.printf "%s\n" Hello.En.v;
  num
;;
let () =
  ignore (mainFunc 20260419)
;;

Finally, rebuild by entering the following commands one at a time:

bash
cd ..
opam exec -- dune exec hello

If you get the result above, the two modules written in ./lib have been built and run successfully.
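The following single-file sketch shows what an `.mli`/`.ml` pair does without creating any files. It uses inline modules: each signature plays the role of the `.mli`, and anything the signature omits (here `hello`) is hidden from other modules. The outer `Hello` module only imitates dune's default wrapping of library modules under the library name. That wrapping is why bin/main.ml writes `Hello.En.v`. Dune generates the wrapper itself, so you would not write it in the project.

```ocaml
(* Counterpart of lib/en.ml + lib/en.mli: the signature exposes only v. *)
module En : sig
  val v : string
end = struct
  let hello = "Hello" (* private: not listed in the signature *)
  let v = hello ^ ", OCaml!"
end

(* Counterpart of lib/es.ml + lib/es.mli. *)
module Es : sig
  val v : string
end = struct
  let v = "Hello FP!"
end

(* Imitation of dune's automatic wrapper for the library named "hello". *)
module Hello = struct
  module En = En
  module Es = Es
end

(* En.hello would be a compile error here; only En.v is visible. *)
let () = print_endline Hello.En.v
```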

Next comes the test. Open ./test/test_hello.ml:

/hello/test/test_hello.ml

ocaml
let test_en=Hello.En.v;;
let test_es=Hello.Es.v;;
let () =
  Printf.printf "%s\n" test_en;
  Printf.printf "%s\n" test_es
;;

Once the test code is written, run the tests as follows:

bash
dune test

If you see output like the above, the test passed. Then use cd to return to the ocaml_workspace directory.
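The test above only prints values, so `dune test` reports success whatever those values are. A test executable fails only when it exits with an error, for example through an uncaught exception. Below is a sketch of a self-checking variant that raises on a mismatch. `En`/`Es` are local stand-ins for `Hello.En`/`Hello.Es`, defined only so the snippet compiles on its own. The helper `check` is hypothetical. In test/test_hello.ml you would pass `Hello.En.v` and `Hello.Es.v` directly.

```ocaml
(* Local stand-ins for Hello.En and Hello.Es, only for a standalone build. *)
module En = struct let v = "Hello, OCaml!" end
module Es = struct let v = "Hello FP!" end

(* Raise Failure (so the test executable exits non-zero) on a mismatch. *)
let check name ~expected actual =
  if actual <> expected then
    failwith (Printf.sprintf "%s: expected %S, got %S" name expected actual)
  else Printf.printf "[%s] passed\n" name

let () =
  check "en" ~expected:"Hello, OCaml!" En.v;
  check "es" ~expected:"Hello FP!" Es.v
```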

Now let's turn to Rocq. Building on the VS Code extensions configured earlier, install vsrocq-language-server and check that it is available with the following steps:

bash
opam install coq-core
opam install vsrocq-language-server
which vsrocqtop
mkdir -p ../.vscode
# note: the next line overwrites any existing ../.vscode/settings.json
echo "{\"vsrocq.path\": \"$(which vsrocqtop)\"}" > ../.vscode/settings.json
opam exec -- dune init proj hellorocq
cd ./hellorocq
> ./test_rocq.v

First, open ./test_rocq.v:

/test_rocq.v

coq
Fixpoint eqb (n m : nat) : bool :=
  match n with
    | O => match m with
      | O => true
      | S m' => false
    end
    | S n' => match m with
      | O => false
      | S m' => eqb n' m'
    end
  end.
Notation "x =? y" := (eqb x y)
  (at level 70) : nat_scope.
Theorem zero_nbeq_plus_1 : forall n : nat,
  0 =? (n + 1) = false.
Proof.
  intros n.
  destruct n eqn:Sn.
  - reflexivity.
  - reflexivity.
Qed.

Next, use the VsRocq controls in the upper-right corner to check the file step by step:

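Press "Step Forward (Alt+↓)" to check one step at a time, or "Interpret To End (Alt+End)" to check the whole .v file at once. The other buttons work in the same way. Below is the result of checking one step in the situation above: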

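The Rocq Prover is an interactive theorem prover (ITP), also called a proof assistant (PA), based on type theory and supporting metaprogramming through its tactic languages. When you prove propositions, theorems or lemmas step by step, and especially when you execute tactics, the panel on the right shows how the current goal is transformed.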

For example, let's continue with the proof above:

/test_rocq.v: Rocq goal state after each step

text
After the theorem statement and "Proof.":
Goal 1
(1 / 1)
forall n : nat, (0 =? n + 1) = false

After "intros n.":
Goal 1
n : nat
(1 / 1)
(0 =? n + 1) = false

After "destruct n eqn:Sn.":
Goal 1
n : nat
Sn : n = 0
(1 / 2)
(0 =? 0 + 1) = false
Goal 2
(2 / 2)
(0 =? S n0 + 1) = false

After the first "-":
Goal 1
n : nat
Sn : n = 0
(1 / 1)
(0 =? 0 + 1) = false

After the first "reflexivity.":
The subproof is complete.
Next unfocused goals (focus with bullet):
Goal 1

After the second "-":
Goal 1
n, n0 : nat
Sn : n = S n0
(1 / 1)
(0 =? S n0 + 1) = false

After the second "reflexivity.":
There are no more subgoals

After "Qed.":
Not in proof mode
zero_nbeq_plus_1 is defined

The above demonstrates how to prove a Theorem with tactics.
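`eqb` is an ordinary structurally recursive function, so it can be mirrored in OCaml as a quick executable sanity check of what `zero_nbeq_plus_1` states. Checking a finite range is not a proof. The OCaml type `nat` and the helpers `add` and `of_int` are illustrative stand-ins for Rocq's `nat`, `Nat.add` and numerals.

```ocaml
(* Peano naturals, mirroring Rocq's nat. *)
type nat = O | S of nat

(* Same case analysis as the Rocq eqb. *)
let rec eqb n m =
  match n, m with
  | O, O -> true
  | O, S _ | S _, O -> false
  | S n', S m' -> eqb n' m'

(* Addition by recursion on the first argument, like Nat.add. *)
let rec add n m = match n with O -> m | S n' -> S (add n' m)

(* Convert a non-negative OCaml int into a Peano natural. *)
let rec of_int k = if k <= 0 then O else S (of_int (k - 1))

(* Executable counterpart of zero_nbeq_plus_1 for n = 0 .. 100. *)
let () =
  for k = 0 to 100 do
    assert (eqb O (add (of_int k) (S O)) = false)
  done;
  print_endline "0 =? n + 1 is false for n = 0..100"
```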

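Next, let's see how to build a Rocq project with dune and how it works together with OCaml. A new dune project was already created earlier. The following commands are run inside the ./hellorocq directory: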

bash
# already inside ./hellorocq after the previous steps, so no cd is needed
mkdir ./theories
> ./theories/Order.v > ./theories/VerifiedMergeSort.v
> ./theories/dune > ./lib/Extract.v
> ./lib/hellorocq.ml > ./lib/hellorocq.mli

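Next, open each file in turn and add, edit or replace its contents. First, add the following line to ./dune-project and update `depends` in the package stanza so that dune supports the Rocq environment: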

/hellorocq/dune-project

lisp
(using rocq 0.12)
...
(package ...
 (depends ocaml dune coq-core)
 ...)
...

Next, modify or replace the other dune files in the directory, mainly ./lib/dune and ./theories/dune:

/hellorocq/lib/dune

lisp
(library
 (name hellorocq)
 (flags
  (:standard -w -33))
 (modules hellorocq Datatypes Orders PeanoNat Mergesort Order VerifiedMergeSort))

(rocq.extraction
 (stdlib no)
 (prelude Extract)
 (extracted_modules Datatypes Orders PeanoNat Mergesort Order VerifiedMergeSort)
 (theories Corelib Stdlib RocqmlVerified))

/hellorocq/theories/dune

lisp
(rocq.theory
 (stdlib no)
 (name RocqmlVerified)
 (modules_flags
  (bar (:standard \ -quiet)))
 (package hellorocq)
 (modules Order VerifiedMergeSort)
 (theories Corelib Stdlib))

Then write the .v, .ml and .mli files involved. For reference:

/hellorocq/bin/main.ml

ocaml
let show_int_list xs =
  let body = xs |> List.map string_of_int |> String.concat "; " in
  "[" ^ body ^ "]"
;;
let default_input = [ 5; 3; 8; 1; 3; 0; 2 ];;
let input =
  match Array.to_list Sys.argv with
  | _ :: [] -> default_input
  | _ :: args -> List.map int_of_string args
  | [] -> default_input
;;
let () =
  let sorted = Hellorocq.sort_nat input in
  Printf.printf "%s\n" (show_int_list sorted)
;;

/hellorocq/lib/hellorocq.mli

ocaml
val sort_nat : int list -> int list

/hellorocq/lib/hellorocq.ml

ocaml
let require_nat n =
  if n < 0 then invalid_arg "Hellorocq.sort_nat: expected non-negative integers"
;;
let sort_nat xs =
  List.iter require_nat xs;
  VerifiedMergeSort.sort_nat xs
;;

/hellorocq/test/test_hellorocq.ml

ocaml
let show_int_list xs =
  let body = xs |> List.map string_of_int |> String.concat "; " in
  "[" ^ body ^ "]"
;;
let check_sort name input expected =
  let actual = Hellorocq.sort_nat input in
  if actual <> expected then
    failwith
      (Printf.sprintf "%s: expected %s, got %s" name (show_int_list expected)
         (show_int_list actual))
  else
    (Printf.printf "[%s] passed \n" name)
;;
let () =
  check_sort "empty" [] [];
  check_sort "singleton" [ 7 ] [ 7 ];
  check_sort "duplicates and boundaries" [ 5; 3; 8; 1; 3; 0; 2 ]
    [ 0; 1; 2; 3; 3; 5; 8 ]
;;

The remaining files are Rocq scripts in .v format:

/hellorocq/lib/Extract.v

coq
From Stdlib Require Import extraction.Extraction.
From Stdlib Require Import extraction.ExtrOcamlBasic.
From Stdlib Require Import extraction.ExtrOcamlNatInt.
From RocqmlVerified Require Import VerifiedMergeSort.

Set Extraction Output Directory "./".
Extraction Language OCaml.
Separate Extraction sort_nat.

/hellorocq/theories/Order.v

coq
From Corelib Require Import Init.Prelude.
From Stdlib Require Import Arith.PeanoNat micromega.Lia Structures.Orders.

Module NatOrder <: TotalLeBool.
  Definition t := nat.
  Definition leb := Nat.leb.
  Infix "<=?" := leb (at level 70, no associativity).
  Theorem leb_total : forall x y, (x <=? y) = true \/ (y <=? x) = true.
  Proof.
    intros x y.
    unfold leb.
    destruct (Nat.leb x y) eqn:Hxy; auto.
    right.
    apply Nat.leb_gt in Hxy.
    apply Nat.leb_le.
    lia.
  Qed.
End NatOrder.

/hellorocq/theories/VerifiedMergeSort.v

coq
From Corelib Require Import Init.Prelude.
From Stdlib Require Import List Sorting.Mergesort Sorting.Permutation Sorting.Sorted.
From RocqmlVerified Require Import Order.

Import ListNotations.
Module NatMergeSort := Mergesort.Sort NatOrder.
Definition sort_nat : list nat -> list nat := NatMergeSort.sort.
Theorem sort_nat_permutation : forall xs, Permutation xs (sort_nat xs).
Proof.
  intro xs.
  unfold sort_nat.
  apply NatMergeSort.Permuted_sort.
Qed.
Theorem sort_nat_sorted :
  forall xs, Sorted (fun x y => Nat.leb x y = true) (sort_nat xs).
Proof.
  intro xs.
  unfold sort_nat.
  apply NatMergeSort.Sorted_sort.
Qed.
Example sort_nat_example :
  sort_nat [5; 3; 8; 1; 3; 0; 2] = [0; 1; 2; 3; 3; 5; 8].
Proof.
  vm_compute.
  reflexivity.
Qed.
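lib/hellorocq.ml needs its guard because `ExtrOcamlNatInt` extracts Rocq's `nat` to OCaml's `int`. An `int` can be negative, but the theorems only speak about natural numbers, so negative inputs fall outside what was proved. Below is a standalone sketch of the same guard pattern. `verified_sort` is a hypothetical stand-in (plain `List.sort`) for the extracted `VerifiedMergeSort.sort_nat`.

```ocaml
(* Hypothetical stand-in for the extracted VerifiedMergeSort.sort_nat. *)
let verified_sort (xs : int list) : int list = List.sort compare xs

(* Reject values outside the domain the Rocq proofs cover (nat >= 0). *)
let require_nat n =
  if n < 0 then invalid_arg "Hellorocq.sort_nat: expected non-negative integers"

(* Validate every element first, then delegate to the "verified" code. *)
let sort_nat xs =
  List.iter require_nat xs;
  verified_sort xs

let () =
  match sort_nat [ 3; -1; 2 ] with
  | _ -> print_endline "unexpected: negative input accepted"
  | exception Invalid_argument msg -> print_endline ("rejected: " ^ msg)
```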

The program can now be built, run and tested with the following commands:

bash
dune build
opam exec -- dune exec hellorocq
dune test
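`sort_nat_sorted` and `sort_nat_permutation` hold for every input. Their runtime counterparts can still be useful as extra spot checks in test_hellorocq.ml. The following sketch assumes you would pass `Hellorocq.sort_nat` as `sort`. The helper names are hypothetical, and the demo uses `List.sort compare` so the snippet runs on its own.

```ocaml
(* Runtime counterpart of sort_nat_sorted: adjacent elements are ordered. *)
let rec is_sorted = function
  | x :: (y :: _ as rest) -> x <= y && is_sorted rest
  | _ -> true

(* Remove the first occurrence of x from a list, if present. *)
let rec remove_first x = function
  | [] -> None
  | y :: ys when y = x -> Some ys
  | y :: ys -> Option.map (fun rest -> y :: rest) (remove_first x ys)

(* Runtime counterpart of sort_nat_permutation: same multiset of elements. *)
let rec is_permutation xs ys =
  match xs with
  | [] -> ys = []
  | x :: xs' -> (
      match remove_first x ys with
      | None -> false
      | Some ys' -> is_permutation xs' ys')

(* Check both properties of a sorting function on one input. *)
let check_sort_properties sort input =
  let output = sort input in
  is_sorted output && is_permutation input output

let () =
  let sort = List.sort compare in
  List.iter
    (fun input -> assert (check_sort_properties sort input))
    [ []; [ 7 ]; [ 5; 3; 8; 1; 3; 0; 2 ] ];
  print_endline "sortedness and permutation hold on all samples"
```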

This completes the verification of Rocq. Let's return to the ocaml_workspace directory and move on to Why3.

bash
sudo yum install gtk3-devel gtksourceview3-devel
opam install alt-ergo z3 cvc5
opam install why3-coq coq-flocq coq-interval
why3 config detect
why3 config list-provers

After installation, the output should look like the following:

With the environment configured, enter the following commands to set up a few simple examples for verifying Why3:

bash
mkdir -p ./why3
cd ./why3
> ./why2ocaml.mlw > ./why2c.mlw > ./why2java.mlw

Fill in the program code for each file as follows:

/why2c.mlw

text
use int.Int
use map.Map as Map
use mach.c.C
use mach.int.Int32
use mach.int.Int64

function ([]) (a: ptr 'a) (i: int): 'a = Map.get a.data.Array.elts (a.offset + i)
type r = { mutable x : int32; mutable y : int32 }
let locate_max (a: ptr int64) (n: int32): int32
  requires { 0 < n }
  requires { valid a n }
  ensures  { 0 <= result < n }
  ensures  { forall i. 0 <= i < n -> a[i] <= a[result] }
= let ref idx = 0 in
  for j = 1 to n - 1 do
    invariant { 0 <= idx < n }
    invariant { forall i. 0 <= i < j -> a[i] <= a[idx] }
    if get_ofs a idx < get_ofs a j then idx <- j
  done;
  idx
let swap (a : r) : unit =
   let tmp = a.y in a.y <- a.x; a.x <- tmp

/why2ocaml.mlw

text
module MaxAndSum
  use int.Int
  use ref.Ref
  use array.Array

  let max_sum (a: array int) (n: int) : (int, int)
    requires { n = length a }
    requires { forall i. 0 <= i < n -> a[i] >= 0 }
    returns  { sum, max -> sum <= n * max }
  = let sum = ref 0 in
    let max = ref 0 in
    for i = 0 to n - 1 do
      invariant { !sum <= i * !max }
      if !max < a[i] then max := a[i];
      sum := !sum + a[i]
    done;
    !sum, !max
  let test () =
    let n = 10 in
    let a = make n 0 in
    a[0] <- 9; a[1] <- 5;
    a[2] <- 0; a[3] <- 2;
    a[4] <- 7; a[5] <- 3;
    a[6] <- 2; a[7] <- 1;
    a[8] <- 10; a[9] <- 6;
    max_sum a n
end

/why2java.mlw

text
module Employee
  use mach.java.lang.Integer
  use mach.java.lang.String
  use mach.java.util.Map

  type service_id = HUMAN_RES | TECHNICAL | BUSINESS
  type t = {
    name: string;
    room: integer;
    phone : string;
    service : service_id;
  }
  let create_employee [@java:constructor]
    (name : string) (room : integer)
    (phone : string) (service : service_id) = {
    name = name;
    room = room;
    phone = phone;
    service = service;
  }
end

module EmployeeAlreadyExistsException [@java:exception:RuntimeException]
  use mach.java.lang.String
  type t [@extraction:preserve_single_field] = { msg : string }
  exception E t
  let constructor[@java:constructor](name : string) : t = {
    msg = (String.format_1 "Employee '%s' already exists" name)
  }
  let getMessage(self : t) : string = self.msg
end

module Directory
  use int.Int
  use mach.java.lang.String
  use mach.java.lang.Integer
  use mach.java.util.Map
  use Employee
  use EmployeeAlreadyExistsException

  type t [@extraction:preserve_single_field]= {
    employees [@java:visibility:private] : Map.map string Employee.t
  }
  let create_directory [@java:constructor] () : t = {
    employees = Map.empty()
  }
  let add_employee (self : t) (name : string)
    (phone : string) (room : integer)
    (service : service_id) : unit
    ensures { Map.containsKey self.employees name }
    ensures { let m = Map.get self.employees name in
               m.name = name && m.phone = phone &&
               m.room = room  && m.service = service }
    raises { EmployeeAlreadyExistsException.E _ ->
        Map.containsKey (old self.employees) name }
  =
    if Map.containsKey self.employees name then
      raise (EmployeeAlreadyExistsException.E (constructor name));
    Map.put self.employees name (
        Employee.create_employee name room phone service)
end

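Why3 programs have a fairly simple structure. Next, use the following commands to execute the code and to extract the WhyML in these files into other languages: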

bash
why3 execute why2ocaml.mlw --use=MaxAndSum 'test ()'
why3 extract -D ocaml64 why2ocaml.mlw -o why2ocaml.ml
why3 extract -D c why2c.mlw -o why2c.h
why3 extract -L . -o . -D java --recursive --modular why2java.mlw
ls

After extraction, `ls` should list the WhyML programs translated into code in the other languages.
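For comparison with the extracted why2ocaml.ml, here is a hand-written OCaml sketch of `max_sum`. It is not the actual output of `why3 extract`, whose exact form depends on the Why3 version and driver. The contract is checked with runtime `assert`s, which is exactly what Why3 lets you avoid by proving the contract statically.

```ocaml
(* Hand-written counterpart of MaxAndSum.max_sum. *)
let max_sum (a : int array) (n : int) : int * int =
  (* Preconditions from the WhyML contract, checked at runtime. *)
  assert (n = Array.length a);
  assert (Array.for_all (fun x -> x >= 0) a);
  let sum = ref 0 and max = ref 0 in
  for i = 0 to n - 1 do
    if !max < a.(i) then max := a.(i);
    sum := !sum + a.(i)
  done;
  (* Postcondition: sum <= n * max. *)
  assert (!sum <= n * !max);
  (!sum, !max)

(* Same data as MaxAndSum.test. *)
let test () = max_sum [| 9; 5; 0; 2; 7; 3; 2; 1; 10; 6 |] 10
```

With this data, `test ()` evaluates to `(45, 10)`.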

Next, let's try using Why3 to call provers, including ITPs such as Rocq, to prove goals:

bash
> ./coq4why.why

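Test with the following sample code. G2 is false because `true -> false` does not hold, so no prover should be able to prove it: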

/coq4why.why

text
theory HelloProof
  use int.Int
  goal G1: true
  goal G2: (true -> false) /\ (true \/ false)
  goal G3: forall x:int. x * x >= 0
end

Then run the following commands one by one. They create a session, call a prover to attempt the goals and replay the recorded session.

bash
why3 session create -o ./coq4whysession ./coq4why.why
why3 prove -P z3 -o ./coq4whysession ./coq4why.why
why3 replay ./coq4whysession

However, this workflow does not currently seem well supported without the GUI, so we will not pursue it here.

Next, let's see how the Why3 API can be used as part of an OCaml program:

bash
> ./why4ocaml.ml

After creating the file, fill in the following code:

/why4ocaml.ml

ocaml
open Why3
open Format

let fmla_true : Term.term = Term.t_true;;
let fmla_false : Term.term = Term.t_false;;
let fmla1 : Term.term = Term.t_or fmla_true fmla_false;;

let prop_var_A : Term.lsymbol =
  Term.create_psymbol (Ident.id_fresh "A") [];;
let prop_var_B : Term.lsymbol =
  Term.create_psymbol (Ident.id_fresh "B") [];;

let atom_A : Term.term =
  Term.ps_app prop_var_A [];;
let atom_B : Term.term =
  Term.ps_app prop_var_B [];;
let fmla2 : Term.term =
  Term.t_implies (Term.t_and atom_A atom_B) atom_A;;

let task1 : Task.task = None
let goal_id1 : Decl.prsymbol =
  Decl.create_prsymbol (Ident.id_fresh "goal1");;
let task1 : Task.task =
  Task.add_prop_decl task1 Decl.Pgoal goal_id1 fmla1;;
let task2 : Task.task = None
let task2 : Task.task =
  Task.add_param_decl task2 prop_var_A;;
let task2 : Task.task =
  Task.add_param_decl task2 prop_var_B;;
let goal_id2 =
  Decl.create_prsymbol (Ident.id_fresh "goal2");;
let task2 = Task.add_prop_decl task2 Decl.Pgoal goal_id2 fmla2;;

let config : Whyconf.config =
  Whyconf.init_config None;;
let main : Whyconf.main =
  Whyconf.get_main config;;
let provers : Whyconf.config_prover Whyconf.Mprover.t =
  Whyconf.get_provers config;;
let limits =
  Call_provers.{empty_limits with
                limit_time = Whyconf.timelimit main;
                limit_mem = Whyconf.memlimit main }
let alt_ergo : Whyconf.config_prover =
  let fp = Whyconf.parse_filter_prover "Alt-Ergo" in
  (* all provers that have the name "Alt-Ergo" *)
  let provers = Whyconf.filter_provers config fp in
  if Whyconf.Mprover.is_empty provers then begin
      eprintf "Prover Alt-Ergo not installed or not configured@.";
      exit 1
    end else begin
      printf "Versions of Alt-Ergo found:";
      Whyconf.(Mprover.iter (fun k _ -> printf " %s" k.prover_version) provers);
      printf "@.";
      (* returning an arbitrary one *)
      snd (Whyconf.Mprover.max_binding provers)
    end
let env : Env.env =
  Env.create_env (Whyconf.loadpath main);;
let alt_ergo_driver : Driver.driver =
  try
    Driver.load_driver_for_prover main env alt_ergo
  with e ->
    eprintf "Failed to load driver for alt-ergo: %a@."
      Exn_printer.exn_printer e;
    exit 1
;;

let call_task (task) : Call_provers.prover_result =
  Call_provers.wait_on_call
  (Driver.prove_task
    ~limits
    ~config:main
    ~command:alt_ergo.Whyconf.command
    alt_ergo_driver
    task)
;;

let () =
  printf "@[formula 1 is:@ %a@]@." Pretty.print_term fmla1;
  printf "@[formula 2 is:@ %a@]@." Pretty.print_term fmla2;
  printf "@[task 1 is:@\n%a@]@." Pretty.print_task task1;
  printf "@[task 2 created:@\n%a@]@." Pretty.print_task task2;
  printf "@[On task 1, Alt-Ergo answers %a@]@."
    (Call_provers.print_prover_result ?json:(Some false)) (call_task task1);
  printf "@[On task 2, Alt-Ergo answers %a@]@."
    (Call_provers.print_prover_result ?json:(Some false)) (call_task task2)
;;

Build and run it as follows (you can also try dune if you like):

bash
ocamlfind ocamlc -package why3 -linkpkg why4ocaml.ml -o why4ocaml
./why4ocaml

The output should look like the figure below:

This completes the verification of the Why3 environment. Now return to ocaml_workspace and try MoonBit, a programming language developed in China.

MoonBit's Why3-based formal verification features are still at an early stage, so we will not discuss them here. Our check of MoonBit is therefore fairly simple:

bash
mkdir -p ./moonbit
cd ./moonbit
moon new neural_network
cd ./neural_network

Next, add the sample code:

/neural_network/neural_network.mbt

moonbit
struct Data {
  input : Array[Double]
  expected : Int
}
let train_dataset : Array[Data] = [
  { input: [5.1, 3.5, 1.4, 0.2], expected: 0 }, { input: [4.9, 3.0, 1.4, 0.2], expected: 0 },
  { input: [4.7, 3.2, 1.3, 0.2], expected: 0 }, { input: [4.6, 3.1, 1.5, 0.2], expected: 0 },
  { input: [5.0, 3.6, 1.4, 0.2], expected: 0 }, { input: [5.4, 3.9, 1.7, 0.4], expected: 0 },
  { input: [4.6, 3.4, 1.4, 0.3], expected: 0 }, { input: [5.4, 3.7, 1.5, 0.2], expected: 0 },
  { input: [4.8, 3.4, 1.6, 0.2], expected: 0 }, { input: [4.8, 3.0, 1.4, 0.1], expected: 0 },
  { input: [4.3, 3.0, 1.1, 0.1], expected: 0 }, { input: [5.8, 4.0, 1.2, 0.2], expected: 0 },
  { input: [5.7, 4.4, 1.5, 0.4], expected: 0 }, { input: [5.4, 3.9, 1.3, 0.4], expected: 0 },
  { input: [5.1, 3.5, 1.4, 0.3], expected: 0 }, { input: [5.7, 3.8, 1.7, 0.3], expected: 0 },
  { input: [5.1, 3.8, 1.5, 0.3], expected: 0 }, { input: [7.0, 3.2, 4.7, 1.4], expected: 1 },
  { input: [6.1, 2.8, 4.7, 1.2], expected: 1 }, { input: [6.4, 2.9, 4.3, 1.3], expected: 1 },
  { input: [6.6, 3.0, 4.4, 1.4], expected: 1 }, { input: [6.8, 2.8, 4.8, 1.4], expected: 1 },
  { input: [6.7, 3.0, 5.0, 1.7], expected: 1 }, { input: [6.4, 2.7, 5.3, 1.9], expected: 2 },
  { input: [6.8, 3.0, 5.5, 2.1], expected: 2 }, { input: [5.7, 2.5, 5.0, 2.0], expected: 2 },
  { input: [5.8, 2.8, 5.1, 2.4], expected: 2 }, { input: [6.0, 2.9, 4.5, 1.5], expected: 1 },
  { input: [5.7, 2.6, 3.5, 1.0], expected: 1 }, { input: [5.5, 2.4, 3.8, 1.1], expected: 1 },
  { input: [5.5, 2.4, 3.7, 1.0], expected: 1 }, { input: [5.8, 2.7, 3.9, 1.2], expected: 1 },
  { input: [6.0, 2.7, 5.1, 1.6], expected: 1 }, { input: [5.4, 3.0, 4.5, 1.5], expected: 1 },
  { input: [6.0, 3.4, 4.5, 1.6], expected: 1 }, { input: [6.7, 3.1, 4.7, 1.5], expected: 1 },
  { input: [6.3, 2.3, 4.4, 1.3], expected: 1 }, { input: [5.6, 3.0, 4.1, 1.3], expected: 1 },
  { input: [5.5, 2.5, 4.0, 1.3], expected: 1 }, { input: [5.5, 2.6, 4.4, 1.2], expected: 1 },
  { input: [6.1, 3.0, 4.6, 1.4], expected: 1 }, { input: [5.8, 2.6, 4.0, 1.2], expected: 1 },
  { input: [5.0, 2.3, 3.3, 1.0], expected: 1 }, { input: [5.6, 2.7, 4.2, 1.3], expected: 1 },
  { input: [5.7, 3.0, 4.2, 1.2], expected: 1 }, { input: [5.0, 3.4, 1.5, 0.2], expected: 0 },
  { input: [4.4, 2.9, 1.4, 0.2], expected: 0 }, { input: [4.9, 3.1, 1.5, 0.1], expected: 0 },
  { input: [5.7, 2.9, 4.2, 1.3], expected: 1 }, { input: [6.2, 2.9, 4.3, 1.3], expected: 1 },
  { input: [5.1, 2.5, 3.0, 1.1], expected: 1 }, { input: [5.7, 2.8, 4.1, 1.3], expected: 1 },
  { input: [6.3, 3.3, 6.0, 2.5], expected: 2 }, { input: [5.8, 2.7, 5.1, 1.9], expected: 2 },
  { input: [7.1, 3.0, 5.9, 2.1], expected: 2 }, { input: [6.3, 2.9, 5.6, 1.8], expected: 2 },
  { input: [6.5, 3.0, 5.8, 2.2], expected: 2 }, { input: [7.6, 3.0, 6.6, 2.1], expected: 2 },
  { input: [4.9, 2.5, 4.5, 1.7], expected: 2 }, { input: [7.3, 2.9, 6.3, 1.8], expected: 2 },
  { input: [6.7, 2.5, 5.8, 1.8], expected: 2 }, { input: [7.2, 3.6, 6.1, 2.5], expected: 2 },
  { input: [6.5, 3.2, 5.1, 2.0], expected: 2 }, { input: [6.4, 3.2, 5.3, 2.3], expected: 2 },
  { input: [6.5, 3.0, 5.5, 1.8], expected: 2 }, { input: [7.7, 3.8, 6.7, 2.2], expected: 2 },
  { input: [7.7, 2.6, 6.9, 2.3], expected: 2 }, { input: [6.0, 2.2, 5.0, 1.5], expected: 2 },
]
let test_dataset : Array[Data] = [
  { input: [5.4, 3.4, 1.7, 0.2], expected: 0 }, { input: [5.1, 3.7, 1.5, 0.4], expected: 0 },
  { input: [4.6, 3.6, 1.0, 0.2], expected: 0 }, { input: [5.1, 3.3, 1.7, 0.5], expected: 0 },
  { input: [4.8, 3.4, 1.9, 0.2], expected: 0 }, { input: [5.0, 3.0, 1.6, 0.2], expected: 0 },
  { input: [5.0, 3.4, 1.6, 0.4], expected: 0 }, { input: [5.2, 3.5, 1.5, 0.2], expected: 0 },
  { input: [5.2, 3.4, 1.4, 0.2], expected: 0 }, { input: [4.7, 3.2, 1.6, 0.2], expected: 0 },
  { input: [4.8, 3.1, 1.6, 0.2], expected: 0 }, { input: [5.4, 3.4, 1.5, 0.4], expected: 0 },
  { input: [5.2, 4.1, 1.5, 0.1], expected: 0 }, { input: [5.5, 4.2, 1.4, 0.2], expected: 0 },
  { input: [4.9, 3.1, 1.5, 0.1], expected: 0 }, { input: [5.0, 3.2, 1.2, 0.2], expected: 0 },
  { input: [5.5, 3.5, 1.3, 0.2], expected: 0 }, { input: [4.9, 3.1, 1.5, 0.1], expected: 0 },
  { input: [4.4, 3.0, 1.3, 0.2], expected: 0 }, { input: [5.1, 3.4, 1.5, 0.2], expected: 0 },
  { input: [5.0, 3.5, 1.3, 0.3], expected: 0 }, { input: [4.5, 2.3, 1.3, 0.3], expected: 0 },
  { input: [4.4, 3.2, 1.3, 0.2], expected: 0 }, { input: [5.0, 3.5, 1.6, 0.6], expected: 0 },
  { input: [5.1, 3.8, 1.9, 0.4], expected: 0 }, { input: [4.8, 3.0, 1.4, 0.3], expected: 0 },
  { input: [5.1, 3.8, 1.6, 0.2], expected: 0 }, { input: [4.6, 3.2, 1.4, 0.2], expected: 0 },
  { input: [5.3, 3.7, 1.5, 0.2], expected: 0 }, { input: [5.0, 3.3, 1.4, 0.2], expected: 0 },
  { input: [6.4, 3.2, 4.5, 1.5], expected: 1 }, { input: [6.9, 3.1, 4.9, 1.5], expected: 1 },
  { input: [5.5, 2.3, 4.0, 1.3], expected: 1 }, { input: [6.5, 2.8, 4.6, 1.5], expected: 1 },
  { input: [5.7, 2.8, 4.5, 1.3], expected: 1 }, { input: [6.3, 3.3, 4.7, 1.6], expected: 1 },
  { input: [4.9, 2.4, 3.3, 1.0], expected: 1 }, { input: [6.6, 2.9, 4.6, 1.3], expected: 1 },
  { input: [5.2, 2.7, 3.9, 1.4], expected: 1 }, { input: [5.0, 2.0, 3.5, 1.0], expected: 1 },
  { input: [5.9, 3.0, 4.2, 1.5], expected: 1 }, { input: [6.0, 2.2, 4.0, 1.0], expected: 1 },
  { input: [6.1, 2.9, 4.7, 1.4], expected: 1 }, { input: [5.6, 2.9, 3.6, 1.3], expected: 1 },
  { input: [6.7, 3.1, 4.4, 1.4], expected: 1 }, { input: [5.6, 3.0, 4.5, 1.5], expected: 1 },
  { input: [5.8, 2.7, 4.1, 1.0], expected: 1 }, { input: [6.2, 2.2, 4.5, 1.5], expected: 1 },
  { input: [5.6, 2.5, 3.9, 1.1], expected: 1 }, { input: [5.9, 3.2, 4.8, 1.8], expected: 1 },
  { input: [6.1, 2.8, 4.0, 1.3], expected: 1 }, { input: [6.3, 2.5, 4.9, 1.5], expected: 1 },
  { input: [6.9, 3.2, 5.7, 2.3], expected: 2 }, { input: [5.6, 2.8, 4.9, 2.0], expected: 2 },
  { input: [7.7, 2.8, 6.7, 2.0], expected: 2 }, { input: [6.3, 2.7, 4.9, 1.8], expected: 2 },
  { input: [6.7, 3.3, 5.7, 2.1], expected: 2 }, { input: [7.2, 3.2, 6.0, 1.8], expected: 2 },
  { input: [6.2, 2.8, 4.8, 1.8], expected: 2 }, { input: [6.1, 3.0, 4.9, 1.8], expected: 2 },
  { input: [6.4, 2.8, 5.6, 2.1], expected: 2 }, { input: [7.2, 3.0, 5.8, 1.6], expected: 2 },
  { input: [7.4, 2.8, 6.1, 1.9], expected: 2 }, { input: [7.9, 3.8, 6.4, 2.0], expected: 2 },
  { input: [6.4, 2.8, 5.6, 2.2], expected: 2 }, { input: [6.3, 2.8, 5.1, 1.5], expected: 2 },
  { input: [6.1, 2.6, 5.6, 1.4], expected: 2 }, { input: [7.7, 3.0, 6.1, 2.3], expected: 2 },
  { input: [6.3, 3.4, 5.6, 2.4], expected: 2 }, { input: [6.4, 3.1, 5.5, 1.8], expected: 2 },
  { input: [6.0, 3.0, 4.8, 1.8], expected: 2 }, { input: [6.9, 3.1, 5.4, 2.1], expected: 2 },
  { input: [6.7, 3.1, 5.6, 2.4], expected: 2 }, { input: [6.9, 3.1, 5.1, 2.3], expected: 2 },
  { input: [5.8, 2.7, 5.1, 1.9], expected: 2 }, { input: [6.8, 3.2, 5.9, 2.3], expected: 2 },
  { input: [6.7, 3.3, 5.7, 2.5], expected: 2 }, { input: [6.7, 3.0, 5.2, 2.3], expected: 2 },
  { input: [6.3, 2.5, 5.0, 1.9], expected: 2 }, { input: [6.5, 3.0, 5.2, 2.0], expected: 2 },
  { input: [6.2, 3.4, 5.4, 2.3], expected: 2 }, { input: [5.9, 3.0, 5.1, 1.8], expected: 2 },
]
fn[T : Base] reLU(t : T) -> T {
  if t.value() < 0.0 {
    T::constant(0.0)
  } else {
    t
  }
}
fn[T : Base] softmax(inputs : Array[T]) -> Array[T] {
  let n = inputs.length()
  let sum = inputs.fold(init=T::constant(0.0), (acc, input) => acc + input.exp())
  let outputs : Array[T] = Array::makei(n, i => inputs[i].exp() / sum)
  outputs
}
fn[T : Base] input2hidden(
  inputs : Array[Double],
  param : Array[Array[T]],
) -> Array[T] {
  let outputs : Array[T] = Array::makei(param.length(), o => {
    reLU(
      inputs.foldi(init=T::constant(0.0), (index, acc, input) => {
        acc + T::constant(input) * param[o][index]
      }) +
      param[o][inputs.length()],
    )
  })
  outputs
}
fn[T : Base] hidden2output(
  inputs : Array[T],
  param : Array[Array[T]],
) -> Array[T] {
  let outputs : Array[T] = Array::makei(param.length(), o => {
    inputs.foldi(init=T::constant(0.0), (index, acc, input) => {
      acc + input * param[o][index]
    }) +
    param[o][inputs.length()]
  })
  outputs |> softmax
}
pub struct Backward {
  value : Double
  propagate : () -> Unit
  backward : (Double) -> Unit
}
impl Base for Backward with constant(d : Double) -> Backward {
  { value: d, propagate: fn() {  }, backward: _ => () }
}
fn Backward::backward(b : Backward) -> Unit {
  (b.propagate)()
  (b.backward)(1.0)
}
impl Base for Backward with value(backward : Backward) -> Double {
  backward.value
}
impl Add for Backward with add(b1 : Backward, b2 : Backward) -> Backward {
  let counter = Ref::{ val: 0 }
  let cumul = Ref::{ val: 0.0 }
  {
    value: b1.value + b2.value,
    propagate: fn() {
      counter.val = counter.val + 1
      if counter.val == 1 {
        (b1.propagate)()
        (b2.propagate)()
      }
    },
    backward: fn(diff) {
      counter.val = counter.val - 1
      cumul.val = cumul.val + diff
      if counter.val == 0 {
        (b1.backward)(cumul.val)
        (b2.backward)(cumul.val)
      }
    },
  }
}
impl Neg for Backward with neg(b : Backward) -> Backward {
  let counter = Ref::{ val: 0 }
  let cumul = Ref::{ val: 0.0 }
  {
    value: -b.value,
    propagate: fn() {
      counter.val = counter.val + 1
      if counter.val == 1 {
        (b.propagate)()
      }
    },
    backward: fn(diff) {
      counter.val = counter.val - 1
      cumul.val = cumul.val + diff
      if counter.val == 0 {
        (b.backward)(-cumul.val)
      }
    },
  }
}
impl Mul for Backward with mul(b1 : Backward, b2 : Backward) -> Backward {
  let counter = Ref::{ val: 0 }
  let cumul = Ref::{ val: 0.0 }
  {
    value: b1.value * b2.value,
    propagate: fn() {
      counter.val = counter.val + 1
      if counter.val == 1 {
        (b1.propagate)()
        (b2.propagate)()
      }
    },
    backward: fn(diff) {
      counter.val = counter.val - 1
      cumul.val = cumul.val + diff
      if counter.val == 0 {
        (b1.backward)(cumul.val * b2.value)
        (b2.backward)(cumul.val * b1.value)
      }
    },
  }
}
impl Div for Backward with div(b1 : Backward, b2 : Backward) -> Backward {
  let counter = Ref::{ val: 0 }
  let cumul = Ref::{ val: 0.0 }
  {
    value: b1.value / b2.value,
    propagate: fn() {
      counter.val = counter.val + 1
      if counter.val == 1 {
        (b1.propagate)()
        (b2.propagate)()
      }
    },
    backward: fn(diff) {
      counter.val = counter.val - 1
      cumul.val = cumul.val + diff
      if counter.val == 0 {
        (b1.backward)(cumul.val / b2.value)
        (b2.backward)(-cumul.val * b1.value / b2.value / b2.value)
      }
    },
  }
}
trait Base: Add + Neg + Mul + Div {
  constant(Double) -> Self
  value(Self) -> Double
  exp(Self) -> Self
}
impl Base for Double with constant(b : Double) -> Double {
  b
}
impl Base for Double with value(b : Double) -> Double {
  b
}
impl Base for Double with exp(b : Double) -> Double {
  @math.exp(b)
}
impl Base for Backward with exp(b : Backward) -> Backward {
  let b_exp = Base::exp(b.value)
  let counter = Ref::{ val: 0 }
  let cumul = Ref::{ val: 0.0 }
  {
    value: b_exp,
    propagate: fn() {
      counter.val = counter.val + 1
      if counter.val == 1 {
        (b.propagate)()
      }
    },
    backward: fn(diff) {
      counter.val = counter.val - 1
      cumul.val = cumul.val + diff
      if counter.val == 0 {
        (b.backward)(cumul.val * b_exp)
      }
    },
  }
}
trait Log {
  log(Self) -> Self
}
impl Log for Double with log(b : Double) -> Double {
  // natural logarithm, so the derivative 1 / x used by Backward's log is exact
  @math.ln(b)
}
impl Log for Backward with log(b : Backward) -> Backward {
  let counter = Ref::{ val: 0 }
  let cumul = Ref::{ val: 0.0 }
  {
    value: Log::log(b.value),
    propagate: fn() {
      counter.val = counter.val + 1
      if counter.val == 1 {
        (b.propagate)()
      }
    },
    backward: fn(diff) {
      counter.val = counter.val - 1
      cumul.val = cumul.val + diff
      if counter.val == 0 {
        (b.backward)(cumul.val / b.value)
      }
    },
  }
}
fn[T : Base + Log] cross_entropy(inputs : Array[T], expected : Int) -> T {
  -inputs[expected].log()
}
fn[T : Base] compute(
  inputs : Array[Double],
  param_hidden : Array[Array[T]],
  param_output : Array[Array[T]],
) -> Array[T] {
  inputs |> input2hidden(param_hidden) |> hidden2output(param_output)
}
fn[N : Base] verify(result : Array[N], expected : Int) -> Bool {
  match expected {
    0 =>
      return result[0].value() > result[1].value() &&
        result[0].value() > result[2].value()
    1 =>
      return result[1].value() > result[0].value() &&
        result[1].value() > result[2].value()
    2 =>
      return result[2].value() > result[1].value() &&
        result[2].value() > result[0].value()
    _ => abort("")
  }
}
fn diff(
  inputs : Array[Double],
  expected : Int,
  param_hidden : Array[Array[Backward]],
  param_output : Array[Array[Backward]],
) -> Unit {
  let result = compute(inputs, param_hidden, param_output)
    |> cross_entropy(expected)
  result.backward()
}
fn update(
  params : Array[Array[Double]],
  diff : Array[Array[Double]],
  step : Double,
) -> Unit {
  for i = 0; i < params.length(); i = i + 1 {
    for j = 0; j < params[i].length(); j = j + 1 {
      params[i][j] -= step * diff[i][j]
    }
  }
}
fn train(
  param_hidden : Array[Array[Double]],
  param_output : Array[Array[Double]],
  step : Double,
) -> Unit {
  let diff_hidden : Array[Array[Double]] = Array::makei(param_hidden.length(), i => {
    Array::make(param_hidden[i].length(), 0.0)
  })
  let backward_hidden : Array[Array[Backward]] = Array::makei(
    param_hidden.length(),
    i => {
      Array::makei(param_hidden[i].length(), j => {
        value: param_hidden[i][j],
        propagate: fn() {  },
        backward: d => diff_hidden[i][j] += d,
      })
    },
  )
  let diff_output : Array[Array[Double]] = Array::makei(param_output.length(), i => {
    Array::make(param_output[i].length(), 0.0)
  })
  let backward_output : Array[Array[Backward]] = Array::makei(
    param_output.length(),
    i => {
      Array::makei(param_output[i].length(), j => {
        value: param_output[i][j],
        propagate: fn() {  },
        backward: d => diff_output[i][j] += d,
      })
    },
  )
  for i = 0; i < train_dataset.length(); i = i + 1 {
    diff(
      train_dataset[i].input,
      train_dataset[i].expected,
      backward_hidden,
      backward_output,
    )
  }
  update(param_hidden, diff_hidden, step)
  update(param_output, diff_output, step)
}
fn main {
  let hidden_layer : Array[Array[Double]] = Array::makei(4, _ => {
    Array::make(5, 0.2)
  })
  let output_layer : Array[Array[Double]] = Array::makei(3, _ => {
    Array::make(5, 0.2)
  })
  let learning_rate = 0.00028
  let decay_rate = -0.001
  for i = 0; i < 400; i = i + 1 {
    train(
      hidden_layer,
      output_layer,
      learning_rate * Base::exp(decay_rate * i.to_double()),
    )
  }
  for i = 0, correct = 0, wrong = 0; i < test_dataset.length(); {
    let result = compute(test_dataset[i].input, hidden_layer, output_layer)
    println(
      "epoch: \{i} Acc: \{correct.to_double()/(wrong.to_double()+correct.to_double())}",
    )
    if verify(result, test_dataset[i].expected) {
      continue i + 1, correct + 1, wrong
    } else {
      continue i + 1, correct, wrong + 1
    }
  } nobreak {
    println("Correct: \{correct} Wrong: \{wrong}")
    println(
      "Final Acc: \{correct.to_double()/(wrong.to_double()+correct.to_double())}",
    )
  }
}
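The `Backward` type above implements reverse-mode automatic differentiation. `propagate` counts how many times each node is used. `backward` accumulates incoming gradients and passes them to its parents only after every user has reported. Below is a minimal OCaml sketch of the same counter-based scheme. It covers only leaves, addition and multiplication, and the names `var`, `combine`, `add`, `mul` and `grad` are illustrative, not part of the MoonBit code. Leaves accumulate into a reference in the same way that `train` accumulates into `diff_hidden` and `diff_output`.

```ocaml
(* A node: its value, a pass that counts uses, and a gradient sink. *)
type node = {
  value : float;
  propagate : unit -> unit;
  backward : float -> unit;
}

(* A leaf (parameter) whose gradient is accumulated into acc. *)
let var (acc : float ref) (v : float) =
  { value = v; propagate = (fun () -> ()); backward = (fun d -> acc := !acc +. d) }

(* Shared logic: count uses on propagate, call emit once all users reported. *)
let combine value parents emit =
  let counter = ref 0 and cumul = ref 0.0 in
  {
    value;
    propagate =
      (fun () ->
        incr counter;
        if !counter = 1 then List.iter (fun p -> p.propagate ()) parents);
    backward =
      (fun d ->
        decr counter;
        cumul := !cumul +. d;
        if !counter = 0 then emit !cumul);
  }

let add a b = combine (a.value +. b.value) [ a; b ] (fun g -> a.backward g; b.backward g)

let mul a b =
  combine (a.value *. b.value) [ a; b ] (fun g ->
      a.backward (g *. b.value);
      b.backward (g *. a.value))

(* Like Backward::backward: count uses first, then seed the output with 1. *)
let grad out =
  out.propagate ();
  out.backward 1.0

let () =
  let gx = ref 0.0 and gy = ref 0.0 in
  let x = var gx 3.0 and y = var gy 4.0 in
  (* f(x, y) = x * y + x, so df/dx = y + 1 and df/dy = x. *)
  let f = add (mul x y) x in
  grad f;
  (* prints: f = 15, df/dx = 5, df/dy = 3 *)
  Printf.printf "f = %g, df/dx = %g, df/dy = %g\n" f.value !gx !gy
```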

Once the code is in place, run it with the following command:

bash
moon run cmd/main

After running it, you should see results like the following:

This concludes the verification of all the tools. Good luck!