(* A state monad over HOL functions.

   A computation of type ('s, 'a) state wraps a function that takes an input
   state of type 's and yields a result of type 'a paired with an output
   state.  The theory provides the monad and applicative operations, the
   primitive get/set/update actions, a list traversal, and lemmas about
   computations whose output state is related to the input state
   (in particular: monotone and strictly monotone computations). *)
theory State_Monad
imports
  "~~/src/HOL/Library/Monad_Syntax"
  "$AFP/Applicative_Lifting/Applicative"
begin

(* The selector run_state unwraps a computation into its underlying function. *)
datatype ('s, 'a) state = State (run_state: "'s \<Rightarrow> ('a \<times> 's)")

(* x is in the datatype-generated set of m iff some input state makes m
   return x. *)
lemma set_state_iff: "x \<in> set_state m \<longleftrightarrow> (\<exists>s s'. run_state m s = (x, s'))"
apply (cases m)
apply auto
apply (metis fsts.cases prod.collapse)
by (metis insert_iff prod_set_simps(1))

(* Introduction rule: P holds for m if it holds for every result m can
   return. *)
lemma pred_stateI[intro]:
  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P a"
  shows "pred_state P m"
apply (subst state.pred_set)
apply rule
apply (subst (asm) set_state_iff)
apply (erule exE)+
apply (rule assms)
apply assumption
done

(* Destruction rule: if P holds for m, it holds for any concrete result. *)
lemma pred_stateD[dest]:
  assumes "pred_state P m" "run_state m s = (a, s')"
  shows "P a"
using assms
apply (cases m)
apply auto
apply (subst (asm) pred_prod_beta)
apply auto
apply (erule allE[where x = s])
apply auto
done

(* Variant of pred_stateD phrased with fst instead of a pattern. *)
lemma pred_state_run_state: "pred_state P m \<Longrightarrow> P (fst (run_state m s))"
apply (drule pred_stateD[where s = s])
apply auto
apply (subst surjective_pairing)
apply (rule refl)
done

(* state_io_rel P m: for every input state s, P relates s to the output
   state that m produces from s.  The returned value is ignored. *)
definition state_io_rel :: "('s \<Rightarrow> 's \<Rightarrow> bool) \<Rightarrow> ('s, 'a) state \<Rightarrow> bool" where
"state_io_rel P m = (\<forall>s. P s (snd (run_state m s)))"

lemma state_io_relI[intro]:
  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P s s'"
  shows "state_io_rel P m"
using assms unfolding state_io_rel_def
by (metis prod.collapse)

lemma state_io_relD[dest]:
  assumes "state_io_rel P m" "run_state m s = (a, s')"
  shows "P s s'"
using assms unfolding state_io_rel_def
by (metis snd_conv)

(* state_io_rel is monotone in the relation argument. *)
lemma state_io_rel_mono[mono]: "P \<le> Q \<Longrightarrow> state_io_rel P \<le> state_io_rel Q"
by blast

(* Extensionality: computations that agree on every input state are equal. *)
lemma state_ext:
  assumes "\<And>s. run_state m s = run_state n s"
  shows "m = n"
using assms
by (cases m; cases n) auto

context begin

(* The constants below are declared qualified, so outside this context they
   are referred to as State_Monad.return, State_Monad.bind, and so on. *)

(* return a yields a and passes the state through unchanged. *)
qualified definition return :: "'a \<Rightarrow> ('s, 'a) state" where
"return a = State (Pair a)"

(* ap f x runs f, then runs x on the state left by f, and applies the
   function returned by f to the value returned by x. *)
qualified definition ap :: "('s, 'a \<Rightarrow> 'b) state \<Rightarrow> ('s, 'a) state \<Rightarrow> ('s, 'b) state" where
"ap f x = State (\<lambda>s.
  case run_state f s of (g, s') \<Rightarrow>
    case run_state x s' of (y, s'') \<Rightarrow> (g y, s''))"

(* bind x f runs x, then runs f applied to the result of x on the state left
   by x. *)
qualified definition bind :: "('s, 'a) state \<Rightarrow> ('a \<Rightarrow> ('s, 'b) state) \<Rightarrow> ('s, 'b) state" where
"bind x f = State (\<lambda>s. case run_state x s of (a, s') \<Rightarrow> run_state (f a) s')"

(* Registers bind with Monad_Syntax so that do-notation can be used. *)
adhoc_overloading Monad_Syntax.bind bind

(* The three monad laws. *)

lemma bind_left_identity[simp]: "bind (return a) f = f a"
unfolding return_def bind_def by simp

lemma bind_right_identity[simp]: "bind m return = m"
unfolding return_def bind_def by simp

lemma bind_assoc[simp]: "bind (bind m f) g = bind m (\<lambda>x. bind (f x) g)"
unfolding bind_def by (auto split: prod.splits)

(* A predicate on results carries over a bind if it holds for every
   continuation reachable from a result of m. *)
lemma bind_predI[intro]:
  assumes "pred_state (\<lambda>x. pred_state P (f x)) m"
  shows "pred_state P (bind m f)"
apply (rule pred_stateI)
unfolding bind_def
using assms
by (auto split: prod.splits)

adhoc_overloading Applicative.ap ap

(* Registers state as an applicative functor for the Applicative_Lifting
   package; the proof obligations are discharged by unfolding ap and return. *)
applicative state
for
  pure: return
  ap: ap
unfolding ap_def return_def
by (auto split: prod.splits)

(* get returns the current state and leaves it unchanged. *)
qualified definition get :: "('s, 's) state" where
"get = State (\<lambda>s. (s, s))"

(* set s' discards the current state and replaces it by s'. *)
qualified definition set :: "'s \<Rightarrow> ('s, unit) state" where
"set s' = State (\<lambda>_. ((), s'))"

(* Writing back the state that was just read is a no-op. *)
lemma get_set[simp]: "bind get set = return ()"
unfolding bind_def get_def set_def return_def
by simp

(* Of two consecutive writes only the second one is observable. *)
lemma set_set[simp]: "bind (set s) (\<lambda>_. set s') = set s'"
unfolding bind_def set_def
by simp

(* traverse_list f xs runs f on the elements of xs from left to right,
   threading the state through, and collects the results in order. *)
fun traverse_list :: "('a \<Rightarrow> ('b, 'c) state) \<Rightarrow> 'a list \<Rightarrow> ('b, 'c list) state" where
"traverse_list _ [] = return []" |
"traverse_list f (x # xs) = do {
  x \<leftarrow> f x;
  xs \<leftarrow> traverse_list f xs;
  return (x # xs)
}"

(* Traversing an appended list equals traversing both parts in sequence. *)
lemma traverse_list_app[simp]: "traverse_list f (xs @ ys) = do {
  xs \<leftarrow> traverse_list f xs;
  ys \<leftarrow> traverse_list f ys;
  return (xs @ ys)
}"
by (induction xs) auto

(* A pure pre-processing step can be moved into a map over the input list. *)
lemma traverse_comp[simp]: "traverse_list (g \<circ> f) xs = traverse_list g (map f xs)"
by (induction xs) auto

(* mono_state m: the output state of m is always at least the input state. *)
abbreviation mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
"mono_state \<equiv> state_io_rel (op \<le>)"

(* strict_mono_state m: the output state of m is always strictly greater
   than the input state. *)
abbreviation strict_mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
"strict_mono_state \<equiv> state_io_rel (op <)"

corollary strict_mono_implies_mono: "strict_mono_state m \<Longrightarrow> mono_state m"
unfolding state_io_rel_def
by (simp add: less_imp_le)

lemma return_mono[simp, intro]: "mono_state (return x)"
unfolding return_def by auto

lemma get_mono[simp, intro]: "mono_state get"
unfolding get_def by auto

(* Overwriting the state is monotone only if the new state is an upper bound
   of every possible old state.  Note that the lemma is named put_mono
   although the operation is called set. *)
lemma put_mono:
  assumes "\<And>x. s' \<ge> x"
  shows "mono_state (set s')"
using assms unfolding set_def
by auto

(* Mapping over the result does not touch the state, so both monotonicity
   properties are preserved. *)

lemma map_mono[intro]: "mono_state m \<Longrightarrow> mono_state (map_state f m)"
by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)

lemma map_strict_mono[intro]: "strict_mono_state m \<Longrightarrow> strict_mono_state (map_state f m)"
by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)

(* Monotonicity of bind.  The continuation only needs to be (strictly)
   monotone for values that m can actually return.  Each proof chains the
   step taken by m with the step taken by the continuation using the
   matching transitivity rule. *)

lemma bind_mono_strong:
  assumes "mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
  shows "mono_state (bind m f)"
apply (rule state_io_relI)
unfolding bind_def
apply (simp split: prod.splits)
apply (rule order_trans)
apply (rule state_io_relD[OF assms(1)])
apply assumption
apply (rule state_io_relD[OF assms(2)])
apply auto
done

(* Monotone first step, strictly monotone continuation. *)
lemma bind_strict_mono_strong1:
  assumes "mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
  shows "strict_mono_state (bind m f)"
apply (rule state_io_relI)
unfolding bind_def
apply (simp split: prod.splits)
apply (rule le_less_trans)
apply (rule state_io_relD[OF assms(1)])
apply assumption
apply (rule state_io_relD[OF assms(2)])
apply auto
done

(* Strictly monotone first step, monotone continuation. *)
lemma bind_strict_mono_strong2:
  assumes "strict_mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
  shows "strict_mono_state (bind m f)"
apply (rule state_io_relI)
unfolding bind_def
apply (simp split: prod.splits)
apply (rule less_le_trans)
apply (rule state_io_relD[OF assms(1)])
apply assumption
apply (rule state_io_relD[OF assms(2)])
apply auto
done

(* update f replaces the current state s by f s. *)
qualified definition update :: "('s \<Rightarrow> 's) \<Rightarrow> ('s, unit) state" where
"update f = bind get (set \<circ> f)"

lemma update_id[simp]: "update (\<lambda>x. x) = return ()"
unfolding update_def return_def get_def set_def bind_def
by auto

(* Two consecutive updates fuse into one update with the composed function. *)
lemma update_comp[simp]: "bind (update f) (\<lambda>_. update g) = update (g \<circ> f)"
unfolding update_def return_def get_def set_def bind_def
by auto

lemma update_mono:
  assumes "\<And>x. x \<le> f x"
  shows "mono_state (update f)"
using assms unfolding update_def get_def set_def bind_def
by (auto intro!: state_io_relI)

lemma update_strict_mono:
  assumes "\<And>x. x < f x"
  shows "strict_mono_state (update f)"
using assms unfolding update_def get_def set_def bind_def
by (auto intro!: state_io_relI)

end

end