theory rewriter
imports Main
begin
(* Output and tracing *)
ML {*
(* Pretty-print trm in this theory's context, showing all brackets. *)
fun term_to_string trm =
  let
    val ctxt = Config.put show_brackets true @{context}
  in
    Syntax.string_of_term ctxt trm
  end

(* Pretty-print a certified term (via term_to_string). *)
fun cterm_to_string ctrm = ctrm |> term_of |> term_to_string

(* Configuration option "rewrite_trace" enabling trace output of the rewriter;
   defaults to false. *)
val rewrite_trace = Attrib.setup_config_bool @{binding rewrite_trace} (K false)

(* Send msg to the tracing channel when rewrite_trace is enabled.
   NOTE(review): @{context} is the context captured when this ML block is
   compiled, so declaring [[rewrite_trace = true]] later presumably has no
   effect here -- confirm before relying on tracing. *)
fun cond_tracing msg =
  if not (Config.get @{context} rewrite_trace) then ()
  else tracing msg
*}
ML {*
exception REWRITER of string

(* Attempt to rewrite ctrm with the single rule rrule (as produced by
   Raw_Simplifier.mk_rrules) by matching the rule's left side against ctrm.
   Returns SOME thm, where thm is the rule instantiated so that ctrm is its
   left side, or NONE if the rule does not match.
   Raises REWRITER if the rule matches but carries the "extra" flag, or if it
   has premises (conditional rules are not supported). Unsupported rules that
   never match are silently ignored.
*)
fun rewrite_once_with_rule rrule ctrm =
  let
    val {thm, name, lhs, elhs, extra, fo, perm = _} = rrule
    val insts =
      if fo then Thm.first_order_match (elhs, ctrm)
      else Thm.match (elhs, ctrm)
    (* Premises of the (uninstantiated) rule; a non-zero count means the rule
       is conditional. *)
    val conditional = Logic.count_prems (Thm.prop_of thm) <> 0
  in
    if extra then raise REWRITER "Cannot handle rules with extra variables"
    else if conditional then raise REWRITER "Cannot handle rules with conditions"
    else
      (* Instantiate only once the rule is known to be usable. *)
      let
        val inst_thm =
          Thm.instantiate insts (Thm.rename_boundvars lhs (term_of ctrm) thm)
      in
        cond_tracing ("Applying rule " ^ name ^ " to write " ^ cterm_to_string ctrm ^
          " as " ^ cterm_to_string (Thm.rhs_of inst_thm));
        SOME inst_thm
      end
  end handle Pattern.MATCH => NONE

(* Rewrite ctrm at the top level with every rule of rrules that matches.
   Returns one theorem (with ctrm as left side) per matching rule, in the
   order of rrules. At most one rule is used in each rewrite.
*)
fun rewrite_once_with_rules rrules ctrm =
  map_filter (fn rrule => rewrite_once_with_rule rrule ctrm) rrules

(* Apply one rule once to a proper subterm of ctrm. Only applications are
   descended into: for ctrm = t1 $ t2, rewrites inside t1 come first, then
   rewrites inside t2. Bodies of lambda abstractions are not rewritten.
   Returns the rewriting theorems, with ctrm as left sides.
*)
fun rewrite_subterm_once rrules ctrm =
  (case term_of ctrm of
    _ $ _ =>
      let
        val (ct1, ct2) = Thm.dest_comb ctrm
        val ct1_rewrites =
          rewrite_all_once rrules ct1
          |> map (fn thm1 => Thm.combination thm1 (Thm.reflexive ct2))
        val ct2_rewrites =
          rewrite_all_once rrules ct2
          |> map (fn thm2 => Thm.combination (Thm.reflexive ct1) thm2)
      in
        ct1_rewrites @ ct2_rewrites
      end
  | _ => [])

(* Apply one rule once to ctrm itself or to a proper subterm of ctrm.
   Top-level rewrites are listed first. Returns the rewriting theorems, with
   ctrm as left sides.
*)
and rewrite_all_once rrules ctrm =
  rewrite_once_with_rules rrules ctrm @ rewrite_subterm_once rrules ctrm
*}
ML {*
(* Data structures used by the search.
   Rewrite_Heap holds entries (score, (ctrm, thm)): score is the "goodness" of
   the result (smaller is better), ctrm is a rewrite of the original term and
   thm is the rewriting theorem, with the original term on the left and ctrm
   on the right. Entries are ordered by score, ties broken by the term order
   of ctrm; the theorem component never influences the order.
*)
fun rewrite_heap_ord ((score1, (ct1, _)), (score2, (ct2, _))) =
  (case int_ord (score1, score2) of
    EQUAL => Term_Ord.term_ord (term_of ct1, term_of ct2)
  | ord => ord)

structure Rewrite_Heap = Heap
(
  type elem = int * (cterm * thm);
  val ord = rewrite_heap_ord
);

(* Maps each rewrite of the original term to its rewriting theorem; used to
   avoid adding the same rewrite twice.
*)
structure CTerm_Table = Table
(
  type key = cterm;
  fun ord (ct1, ct2) = Term_Ord.term_ord (term_of ct1, term_of ct2)
)
*}
ML {*
(* goal_fun : cterm -> int measures the "goodness" of a result. The returned
   value must be >= 0; a smaller score means a better result. The search
   stops as soon as a result with score 0 has been found.

   rewrite_aux performs a best-first search over rewrites of ctrm using
   rrules, expanding at most max_steps terms and discarding rewrites whose
   score exceeds max_score. Returns a theorem ctrm == best, where best is the
   best-scoring rewrite found (ctrm itself if nothing better was found).
*)
fun rewrite_aux rrules goal_fun max_steps max_score ctrm =
  let
    (* Search state:
       found      - table of all rewrites discovered so far.
       to_test    - heap of rewrites still to be expanded.
       best_found - best entry discovered so far, whether expanded or not.
    *)
    val refl = Thm.reflexive ctrm
    val init = (goal_fun ctrm, (ctrm, refl))
    val found = Unsynchronized.ref (CTerm_Table.empty |> CTerm_Table.update (ctrm, refl))
    val to_test = Unsynchronized.ref (Rewrite_Heap.empty |> Rewrite_Heap.insert init)
    val best_found = Unsynchronized.ref init

    (* Record thm (original term == rhs) as a new result, unless rhs was
       already found or its score exceeds max_score. best_found is updated
       here, rather than when an entry is expanded, so that results found in
       the last permitted step are not lost and a score-0 result ends the
       search immediately.
    *)
    fun add_new_found thm =
      let
        val rhs = Thm.rhs_of thm
        val score = goal_fun rhs
        val entry = (score, (rhs, thm))
      in
        if CTerm_Table.defined (!found) rhs orelse score > max_score then ()
        else
          (cond_tracing ("Adding: " ^ cterm_to_string rhs ^ " (" ^ string_of_int score ^ ")");
           found := CTerm_Table.update (rhs, thm) (!found);
           to_test := Rewrite_Heap.insert entry (!to_test);
           if rewrite_heap_ord (entry, !best_found) = LESS then best_found := entry
           else ())
      end

    (* Main loop. Stop when the number of expansions reaches the limit, when
       the heap is empty (no more rewrites), or when a result with score 0
       has been found.
    *)
    fun process_queue limit =
      let val (best_val, (_, best_thm)) = !best_found in
        if limit = 0 orelse Rewrite_Heap.is_empty (!to_test) orelse best_val = 0
        then best_thm
        else
          let
            val (_, (rhs, ori_thm)) = Rewrite_Heap.min (!to_test)
            val _ = to_test := Rewrite_Heap.delete_min (!to_test)
            (* Extend ori_thm (original == rhs) by every one-step rewrite of rhs. *)
            val new_thms =
              rewrite_all_once rrules rhs
              |> map (fn thm => Thm.transitive ori_thm thm)
            val _ = List.app add_new_found new_thms
          in
            process_queue (limit - 1)
          end
      end
  in
    process_queue max_steps
  end

(* Conversion rewriting ctrm with thms, searching with rewrite_aux for the
   best result according to goal_fun. Under Trueprop the argument is
   rewritten; for an HOL equation each side is rewritten independently.
*)
fun rewrite thms goal_fun max_steps max_score =
  let
    (* Build the rewrite rules once per conversion instead of once per
       rewritten subterm. *)
    val rrules = Raw_Simplifier.mk_rrules @{context} thms
    fun rewrite_conv ctrm =
      (case term_of ctrm of
        Const ("HOL.Trueprop", _) $ _ => Conv.arg_conv rewrite_conv ctrm
      | Const ("HOL.eq", _) $ _ $ _ => Conv.binop_conv rewrite_conv ctrm
      | _ => rewrite_aux rrules goal_fun max_steps max_score ctrm)
  in
    rewrite_conv
  end
*}
ML {*
(* Rewrite towards the "shortest" result: the score of a term is its size. *)
fun rewrite_to_min thms =
  let
    fun size_score ct = Term.size_of_term (term_of ct)
  in
    rewrite thms size_score
  end

(* Rewrite towards goal (a term): a term alpha-equivalent to goal scores 0,
   any other term scores its size. *)
fun rewrite_to_goal thms goal =
  let
    fun goal_score ct =
      let val t = term_of ct
      in if t aconv goal then 0 else Term.size_of_term t end
  in
    rewrite thms goal_score
  end
*}
ML {*
(* Rewrite every subgoal (premise) of the goal state th with rewrite_to_min,
   using thms and the given search limits. *)
fun rewrite_goals_min thms max_steps max_score th =
  let
    val subgoal_conv = rewrite_to_min thms max_steps max_score
  in
    th |> Conv.fconv_rule (Conv.prems_conv ~1 subgoal_conv)
  end

(* Tactic form of rewrite_goals_min. *)
fun rewrite_goals_tac thms max_steps max_score =
  rewrite_goals_min thms max_steps max_score |> PRIMITIVE
*}
(* Example: a theory of permutations.
   A permutation is a list of swaps; pi \<bullet> x applies pi to x. On nat the
   swaps are applied starting from the last list element; on pairs and lists
   the permutation is applied componentwise / elementwise. *)
type_synonym prm = "(nat \<times> nat) list"
consts perm :: "prm \<Rightarrow> 'a \<Rightarrow> 'a" ("_ \<bullet> _" [80,80] 80)
overloading
  perm_nat \<equiv> "perm :: prm \<Rightarrow> nat \<Rightarrow> nat"
  perm_prod \<equiv> "perm :: prm \<Rightarrow> ('a \<times> 'b) \<Rightarrow> ('a \<times> 'b)"
  perm_list \<equiv> "perm :: prm \<Rightarrow> 'a list \<Rightarrow> 'a list"
begin
(* swap a b exchanges a and b and leaves every other number unchanged. *)
fun swap::"nat \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"
where
  "swap a b c = (if c=a then b else (if c=b then a else c))"
primrec perm_nat
where
  "perm_nat [] c = c"
| "perm_nat (ab # pi) c = swap (fst ab) (snd ab) (perm_nat pi c)"
fun perm_prod
where
  "perm_prod pi (x, y) = (pi \<bullet> x, pi \<bullet> y)"
primrec perm_list
where
  "perm_list pi [] = []"
| "perm_list pi (x#xs) = (pi \<bullet> x)#(perm_list pi xs)"
end
(* Applying an appended permutation applies pi2 first, then pi1. *)
lemma perm_append[simp]:
  fixes c::"nat" and pi1 pi2 ::"prm"
  shows "((pi1 @ pi2) \<bullet> c) = (pi1 \<bullet> (pi2 \<bullet> c))"
  by (induct pi1) (auto)
(* Permutations are injective on nat. *)
lemma perm_bij[simp]:
  fixes c d::"nat" and pi::"prm"
  shows "(pi \<bullet> c = pi \<bullet> d) = (c = d)"
  by (induct pi) (auto)
(* rev pi undoes pi. *)
lemma perm_rev[simp]:
  fixes c::"nat" and pi::"prm"
  shows "pi \<bullet> ((rev pi) \<bullet> c) = c"
  by (induct pi arbitrary: c) (auto)
lemma perm_compose:
  fixes c::"nat" and pi1 pi2 ::"prm"
  shows "pi1 \<bullet> (pi2 \<bullet> c) = (pi1 \<bullet> pi2) \<bullet> (pi1 \<bullet> c)"
  by (induct pi2) (auto)
(* Manual proof of the goal that is proved again below with the rewriter. *)
lemma
  fixes c d :: "nat" and pi1 pi2 :: "prm"
  shows "pi1 \<bullet> (c, pi2 \<bullet> ((rev pi1) \<bullet> d)) = (pi1 \<bullet> c, (pi1 \<bullet> pi2) \<bullet> d)"
  apply (simp)
  apply (rule trans)
  apply (rule perm_compose)
  apply (simp)
  done
ML {*
(* Rewrite rules for the permutation example, in the order they are tried. *)
val perm_thms =
  @{thms perm_prod.simps perm_append perm_bij perm_compose perm_rev
         perm_list.simps perm_nat.simps}
*}
(* The same goal, proved by rewriting each side of the equation towards a
   smaller term with perm_thms (at most 11 expansion steps; rewrites whose
   size exceeds 30 are discarded) and finishing with auto. *)
lemma
  fixes c d :: "nat" and pi1 pi2 :: "prm"
  shows "pi1 \<bullet> (c, pi2 \<bullet> ((rev pi1) \<bullet> d)) = (pi1 \<bullet> c, (pi1 \<bullet> pi2) \<bullet> d)"
  apply (tactic {* rewrite_goals_tac perm_thms 11 30 *})
  apply (auto)
  done
(* Example: ring operations (here on the semiring nat). *)
ML {*
(* Associativity/commutativity rules for + and *, plus distributivity used in
   both directions (expanding and factoring). *)
val ring_thms =
  let
    val distrib_thm = @{thm distrib}
  in
    @{thms add_ac mult_ac} @ [distrib_thm, distrib_thm RS @{thm sym}]
  end
*}
(* Proved by rewriting each side of the equation towards a smaller term with
   ring_thms (at most 15 expansion steps; rewrites whose size exceeds 20 are
   discarded), then finishing with auto. *)
lemma
fixes a b c d e f :: "nat"
shows "(a + b) + (c * (d + e)) + f = (f + a) + c * d + c * e + b"
apply ( tactic {* rewrite_goals_tac ring_thms 15 20 *} )
apply (auto)
done