import%20marimo%0A%0A__generated_with%20%3D%20%220.19.8%22%0Aapp%20%3D%20marimo.App()%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%20Discrete%20time%20wave%20equation%20PINN%20solver%0A%0A%20%20%20%20This%20notebook%20demonstrates%20how%20to%20use%20Jaxfun%20to%20solve%20the%20discrete%20time%20wave%20equation%20using%20Physics-Informed%20Neural%20Networks%20(PINNs).%20The%20wave%20equation%20is%20a%20second-order%20partial%20differential%20equation%20that%20describes%20the%20propagation%20of%20waves%2C%20such%20as%20sound%20or%20light%20waves%2C%20through%20a%20medium%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cbegin%7Balign%7D%0A%20%20%20%20%5Cfrac%7B%5Cpartial%5E2%20u%7D%7B%5Cpartial%20t%5E2%7D%20%26%3D%20c%5E2%20%5Cnabla%5E2%20u%2C%20%5Cquad%20x%2C%20t%20%5Cin%20%5B0%2C%20%5Cpi%5D%20%5Ctimes%20%5B0%2C%20T%5D%20%5C%5C%0A%20%20%20%20u(x%2C%200)%20%26%3D%20%5Csin(x)%2C%20%5Cquad%20x%20%5Cin%20%5B0%2C%20%5Cpi%5D%20%5C%5C%0A%20%20%20%20%5Cfrac%7B%5Cpartial%20u%7D%7B%5Cpartial%20t%7D(x%2C%200)%20%26%3D%20c%5Csin(x)%2C%20%5Cquad%20x%20%5Cin%20%5B0%2C%20%5Cpi%5D%0A%20%20%20%20%5Cend%7Balign%7D%0A%20%20%20%20%24%24%0A%0A%20%20%20%20Here%20%24u(x%2C%20t)%24%20is%20the%20solution%20we%20want%20to%20approximate%2C%20%24c%24%20is%20the%20wave%20speed%2C%20and%20%24%5Cnabla%5E2%24%20is%20the%20Laplacian%20operator.%20The%20initial%20condition%20is%20derived%20from%20an%20exact%20solution%20%24u(x%2C%20t)%20%3D%20%5Csin(x)(%5Csin(ct)%20%2B%20%5Ccos(ct))%24.%20For%20our%20example%20in%20one%20spatial%20dimension%2C%20the%20equation%20simplifies%20to%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cfrac%7B%5Cpartial%5E2%20u%7D%7B%5Cpartial%20t%5E2%7D%20%3D%20c%5E2%20%5Cfrac%7B%5Cpartial%5E2%20u%7D%7B%5Cpartial%20x%5E2%7D%0A%20%20%20%20%24%24%0A%0A%20%20%20%20We%20will%20discretize%20the%20time%20domain%20into%20%24N%24%20discrete%20steps%20with%20a%20time%20step%20size%20of%20%24%5CDelta%20t%20%3D%20%5Cfrac%7BT%7D%7BN%7D%24.%20We%20will%20use%20Jaxfun%20to%20define%20a%20neural%20network%20for%20the%20solutions%20at%20each%20time%20step%0A%0A%20%20%20%20%24%24%0A%20%20%20%20u_%7B%5Ctheta%7D%5En(x)%20%3D%20u(x%2C%20n%20%5CDelta%20t)%2C%20%5Cquad%20n%20%3D%200%2C%201%2C%20%5Cldots%2C%20N%0A%20%20%20%20%24%24%0A%0A%20%20%20%20and%20discretize%20the%20wave%20equation%20in%20time%20with%20a%20second%20order%20finite%20difference%20scheme%20as%20follows%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20u_%7B%5Ctheta%7D%5E%7Bn%2B1%7D(x)%20%3D%202u_%7B%5Ctheta%7D%5En(x)%20-%20u_%7B%5Ctheta%7D%5E%7Bn-1%7D(x)%20%2B%20c%5E2%20%5CDelta%20t%5E2%20%5Cfrac%7B%5Cpartial%5E2%20u_%7B%5Ctheta%7D%5En%7D%7B%5Cpartial%20x%5E2%7D%0A%20%20%20%20%24%24%0A%0A%20%20%20%20We%20will%20train%20the%20neural%20networks%20to%20minimize%20the%20residual%20of%20the%20discretized%20wave%20equation%20at%20each%20time%20step%2C%20along%20with%20the%20initial%20conditions.%20For%20the%20initial%20Dirichlet%20condition%2C%20we%20will%20use%20a%20mean%20squared%20error%20loss%20between%20the%20neural%20network%20output%20and%20the%20initial%20condition.%20Using%20%24M%24%20collocation%20points%20%24x_i%24%20in%20the%20spatial%20domain%2C%20the%20loss%20function%20for%20the%20initial%20condition%20is%20defined%20as%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cmathcal%7BL%7D_%7B%5Ctext%7BIC%7D%7D%20%3D%20%5Csum_%7Bi%3D1%7D%5E%7BM%7D%20%5Cleft(%20%5Comega_i(%20u_%7B%5Ctheta%7D%5E0(x_i)%20-%20%5Csin(x_i))%20%5Cright)%5E2%2C%0A%20%20%20%20%24%24%0A%0A%20%20%20%20where%20%24%5Comega_i%24%20are%20weights%20that%20can%20be%20used%20to%20emphasize%20certain%20points%20in%20the%20domain.%20The%20weights%20are%20by%20default%20set%20to%20%241%2FM%24%2C%20but%20choosing%20different%20weights%20can%20help%20improve%20convergence%20in%20some%20cases.%20By%20choosing%20Legendre%20points%20as%20collocation%20points%20we%20can%20use%20the%20corresponding%20quadrature%20weights%20such%20that%20the%20loss%20more%20closely%20approximates%20an%20integral%20over%20the%20domain.%0A%0A%20%20%20%20We%20will%20use%20the%20second%20initial%20condition%20to%20compute%20%24u_%7B%5Ctheta%7D%5E1(x)%24%20directly%20from%20%24u_%7B%5Ctheta%7D%5E0(x)%24%20using%20the%20discrete%20wave%20equation%20at%20time%20step%20%24n%3D0%24%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20u_%7B%5Ctheta%7D%5E%7B1%7D(x)%20%3D%202u_%7B%5Ctheta%7D%5E0(x)%20-%20u_%7B%5Ctheta%7D%5E%7B-1%7D(x)%20%2B%20c%5E2%20%5CDelta%20t%5E2%20%5Cfrac%7B%5Cpartial%5E2%20u_%7B%5Ctheta%7D%5E0%7D%7B%5Cpartial%20x%5E2%7D%0A%20%20%20%20%24%24%0A%0A%20%20%20%20Here%20%24u_%7B%5Ctheta%7D%5E%7B-1%7D(x)%24%20can%20be%20computed%20from%20the%20initial%20velocity%20condition%20as%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cfrac%7Bu_%7B%5Ctheta%7D%5E%7B1%7D%20-%20u_%7B%5Ctheta%7D%5E%7B-1%7D%7D%7B2%20%5CDelta%20t%7D%20%5Capprox%20c%5Csin(x)%0A%20%20%20%20%24%24%0A%0A%20%20%20%20Thus%2C%20we%20can%20express%20%24u_%7B%5Ctheta%7D%5E%7B-1%7D(x)%24%20as%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20u_%7B%5Ctheta%7D%5E%7B-1%7D(x)%20%3D%20u_%7B%5Ctheta%7D%5E%7B1%7D(x)%20-%202%20%5CDelta%20t%20c%20%5Csin(x)%0A%20%20%20%20%24%24%0A%0A%20%20%20%20and%20the%20equation%20to%20solve%20for%20%24u_%7B%5Ctheta%7D%5E%7B1%7D(x)%24%20becomes%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20u_%7B%5Ctheta%7D%5E%7B1%7D(x)%20%3D%20u_%7B%5Ctheta%7D%5E0(x)%20%2B%20%5CDelta%20t%20c%20%5Csin(x)%20%2B%20%5Cfrac%7Bc%5E2%20%5CDelta%20t%5E2%7D%7B2%7D%20%5Cfrac%7B%5Cpartial%5E2%20u_%7B%5Ctheta%7D%5E0%7D%7B%5Cpartial%20x%5E2%7D%0A%20%20%20%20%24%24%0A%0A%20%20%20%20The%20loss%20function%20for%20this%20equation%20is%20then%20defined%20as%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cmathcal%7BL%7D_%7B%5Ctext%7BIC2%7D%7D%20%3D%20%5Csum_%7Bi%3D1%7D%5E%7BM%7D%20%5Cleft(%20%5Comega_i%20%5Cleft(%20u_%7B%5Ctheta%7D%5E%7B1%7D(x_i)%20-%20%5Cleft(%20u_%7B%5Ctheta%7D%5E0(x_i)%20%2B%20%5CDelta%20t%20c%20%5Csin(x_i)%20%2B%20%5Cfrac%7Bc%5E2%20%5CDelta%20t%5E2%7D%7B2%7D%20%5Cfrac%7B%5Cpartial%5E2%20u_%7B%5Ctheta%7D%5E0%7D%7B%5Cpartial%20x%5E2%7D%20%5Cbigg%7C_%7Bx%3Dx_i%7D%20%5Cright)%20%5Cright)%20%5Cright)%5E2%0A%20%20%20%20%24%24%0A%0A%20%20%20%20The%20boundary%20condition%20at%20each%20time%20step%20is%20enforced%20by%20adding%20a%20penalty%20term%20to%20the%20loss%20function%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cmathcal%7BL%7D%5En_%7B%5Ctext%7BBC%7D%7D%20%3D%20%5Cleft(%20%5Comega_i%20(u_%7B%5Ctheta%7D%5En(0)%5E2%20%2B%20u_%7B%5Ctheta%7D%5En(%5Cpi)%5E2)%20%5Cright)%2C%20%5Cquad%20n%20%3D%200%2C%201%2C%20%5Cldots%2C%20N%0A%20%20%20%20%24%24%0A%0A%20%20%20%20For%20all%20time%20steps%20%24n%20%3D%201%2C%202%2C%20%5Cldots%2C%20N-1%24%2C%20the%20residual%20loss%20of%20the%20discretized%20wave%20equation%20is%20defined%20as%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cmathcal%7BL%7D_%7B%5Ctext%7BPDE%7D%7D%5En%20%3D%20%5Csum_%7Bi%3D1%7D%5E%7BM%7D%20%5Cleft(%20%5Comega_i%20%5Cleft(%20u_%7B%5Ctheta%7D%5E%7Bn%2B1%7D(x_i)%20-%20%5Cleft(%202u_%7B%5Ctheta%7D%5En(x_i)%20-%20u_%7B%5Ctheta%7D%5E%7Bn-1%7D(x_i)%20%2B%20c%5E2%20%5CDelta%20t%5E2%20%5Cfrac%7B%5Cpartial%5E2%20u_%7B%5Ctheta%7D%5En%7D%7B%5Cpartial%20x%5E2%7D%20%5Cbigg%7C_%7Bx%3Dx_i%7D%20%5Cright)%20%5Cright)%20%5Cright)%5E2%0A%20%20%20%20%24%24%0A%0A%20%20%20%20Note%20that%20the%20unknown%20neural%20network%20%24u%5E%7Bn%2B1%7D_%7B%5Ctheta%7D(x)%24%20appears%20in%20the%20loss%20without%20any%20derivatives%2C%20which%20is%20very%20convenient%20for%20fast%20training.%20The%20known%20terms%20on%20the%20right-hand%20side%20can%20be%20evaluated%20once%20for%20each%20time%20step%20and%20reused%20for%20all%20training%20iterations%20at%20that%20time.%0A%0A%20%20%20%20We%20will%20start%20by%20creating%20a%20collocation%20of%20points%20in%20the%20spatial%20domain%20%24%5B0%2C%20%5Cpi%5D%24%20to%20evaluate%20the%20losses.%20Then%2C%20we%20will%20define%20a%20neural%20network%20architecture%20for%20the%20solution%20at%20each%20time%20step%20and%20set%20up%20the%20training%20loop%20to%20minimize%20the%20total%20loss.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20jax%0A%0A%20%20%20%20jax.config.update(%22jax_enable_x64%22%2C%20True)%0A%20%20%20%20import%20jax.numpy%20as%20jnp%0A%0A%20%20%20%20from%20jaxfun.pinns.mesh%20import%20Line%0A%0A%20%20%20%20M%20%3D%20256%0A%20%20%20%20mesh%20%3D%20Line(0%2C%20jnp.pi)%0A%20%20%20%20return%20M%2C%20jnp%2C%20mesh%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Here%20%60mesh%60%20is%20a%20one-dimensional%20spatial%20domain%20defined%20from%200%20to%20%CF%80.%20We%20can%20use%20this%20domain%20to%20create%20grid%20points%20for%20our%20computations%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(M%2C%20mesh)%3A%0A%20%20%20%20xa%20%3D%20mesh.get_points(M%2C%20domain%3D%22all%22%2C%20kind%3D%22legendre%22)%0A%20%20%20%20xi%20%3D%20mesh.get_points(M%2C%20domain%3D%22inside%22%2C%20kind%3D%22legendre%22)%0A%20%20%20%20xb%20%3D%20mesh.get_points(M%2C%20domain%3D%22boundary%22%2C%20kind%3D%22legendre%22)%0A%20%20%20%20wi%20%3D%20mesh.get_weights(M%2C%20domain%3D%22inside%22%2C%20kind%3D%22legendre%22)%0A%20%20%20%20return%20wi%2C%20xb%2C%20xi%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Here%20%60xa%60%20contains%20all%20the%20M%20collocation%20points%20in%20the%20spatial%20domain%20%5B0%2C%20%CF%80%5D%2C%20including%20the%20boundary%20points%200%20and%20%CF%80.%20The%20interior%20and%20boundary%20points%20are%20represented%20by%20%60xi%60%20and%20%60xb%60%2C%20respectively.%0A%0A%20%20%20%20Next%20we%20define%20the%20neural%20network%20architecture%20for%20the%20solution%20at%20each%20time%20step.%20We%20will%20use%20a%20simple%20feedforward%20neural%20network%20with%20one%20hidden%20layer%20of%20size%2064%20and%20the%20softmax%20activation%20function.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20sympy%20as%20sp%0A%20%20%20%20from%20flax%20import%20nnx%0A%0A%20%20%20%20from%20jaxfun.pinns.nnspaces%20import%20MLPSpace%0A%0A%20%20%20%20V%20%3D%20MLPSpace(64%2C%20dims%3D1%2C%20transient%3DFalse%2C%20rank%3D0%2C%20name%3D%22V%22%2C%20act_fun%3Dnnx.softmax)%0A%20%20%20%20x%20%3D%20V.system.x%0A%20%20%20%20t%20%3D%20sp.Symbol(%22t%22%2C%20real%3DTrue)%0A%20%20%20%20return%20V%2C%20nnx%2C%20sp%2C%20t%2C%20x%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20neural%20network%20is%20defined%20using%20the%20%60FlaxFunction%60%20class%20from%20Jaxfun.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(V%2C%20nnx)%3A%0A%20%20%20%20from%20jaxfun.pinns.module%20import%20FlaxFunction%0A%0A%20%20%20%20u%20%3D%20FlaxFunction(V%2C%20name%3D%22u%22%2C%20rngs%3Dnnx.Rngs(12))%0A%20%20%20%20return%20(u%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20%60FlaxFunction%60%20%60u%60%20represents%20the%20neural%20network%20and%20contains%20an%20%60nnx.Module%60.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(u)%3A%0A%20%20%20%20u%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20%60FlaxFunction%60%20is%20a%20subclass%20of%20a%20sympy%20%60Function%60%20that%20wraps%20a%20Flax%20neural%20network%20module.%20We%20can%20treat%20it%20as%20any%20other%20sympy%20function%20in%20our%20equations%2C%20and%20take%20the%20derivative%20with%20respect%20to%20its%20input%20variable%20%60x%60.%20It%20can%20also%20be%20evaluated%20at%20any%20given%20points%2C%20as%20%60u(xa)%60.%20For%20any%20terms%20involving%20derivatives%20of%20%60u%60%2C%20please%20use%20%60evaluate%60%20from%20%60jaxfun.pinns.loss%60%20for%20explicit%20evaluations%20of%20the%20known%20function.%0A%0A%20%20%20%20We%20start%20by%20computing%20%24u%5E0_%7B%5Ctheta%7D%24%20using%20%24%5Cmathcal%7BL%7D_%7BIC%7D%24%20and%20a%20trainer.%20The%20loss%20function%20is%20created%20using%20the%20%60Loss%60%20class%20from%20Jaxfun.%20The%20%60Loss%60%20class%20takes%20a%20tuple%20of%20problems%20to%20solve%20and%20constructs%20the%20corresponding%20loss%20functions.%20Here%20we%20feed%20it%20the%20boundary%20condition%20%24%5Cmathcal%7BL%7D_%7BBC%7D%5E0%24%20and%20the%20initial%20condition%20problem%20%24u-%5Csin(x)%20%3D%200%24%2C%20along%20with%20collocation%20points%20%60xi%60.%20The%20tuples%20fed%20to%20%60Loss%60%20are%20of%20the%20form%20%60(equation%2C%20points%2C%20target%2C%20weights)%60%2C%20where%20%60equation%60%20is%20the%20equation%20to%20solve%2C%20%60points%60%20are%20the%20collocation%20points%2C%20%60target%60%20is%20the%20target%20value%20for%20the%20equation%20(usually%20zero)%2C%20and%20%60weights%60%20are%20the%20weights%20for%20each%20term%20in%20the%20loss%20function.%20The%20default%20target%20is%200%2C%20and%20the%20default%20weights%20are%201%20over%20the%20number%20of%20collocation%20points.%20Here%20we%20choose%20the%20weights%20to%20be%20the%20quadrature%20weights%20corresponding%20to%20the%20Legendre%20points.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(sp%2C%20u%2C%20wi%2C%20x%2C%20xb%2C%20xi)%3A%0A%20%20%20%20from%20jaxfun.pinns.loss%20import%20Loss%0A%20%20%20%20from%20jaxfun.pinns.optimizer%20import%20DiscreteTimeTrainer%20as%20Trainer%0A%0A%20%20%20%20loss_fn%20%3D%20Loss((u%2C%20xb)%2C%20(u%20-%20sp.sin(x)%2C%20xi%2C%200%2C%20wi))%0A%20%20%20%20trainer%20%3D%20Trainer(loss_fn)%0A%20%20%20%20return%20loss_fn%2C%20trainer%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%60loss_fn%60%20can%20now%20compute%20the%20current%20loss%20using%20its%20initialized%20weights.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn%2C%20u)%3A%0A%20%20%20%20loss_fn(u.module)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20solve%20using%20first%20order%20Adam%20optimizer%20with%20learning%20rate%200.001%20for%205000%20iterations%20and%20then%20switch%20to%20second%20order%20L-BFGS%20optimizer%20for%20a%20few%20more%20iterations%20to%20get%20a%20sufficiently%20low%20loss.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(trainer%2C%20u)%3A%0A%20%20%20%20from%20jaxfun.pinns.optimizer%20import%20adam%2C%20lbfgs%0A%0A%20%20%20%20opt_adam%20%3D%20adam(u)%0A%20%20%20%20opt_lbfgs%20%3D%20lbfgs(u%2C%20memory_size%3D10)%0A%0A%20%20%20%20trainer.train(opt_adam%2C%205000%2C%20epoch_print%3D1000)%0A%20%20%20%20trainer.train(%0A%20%20%20%20%20%20%20%20opt_lbfgs%2C%205000%2C%20epoch_print%3D1000%2C%20abs_limit_loss%3D1e-11%2C%20print_final_loss%3DTrue%0A%20%20%20%20)%0A%20%20%20%20trainer.step(u.module)%0A%20%20%20%20return%20(opt_lbfgs%2C)%0A%0A%0A%40app.cell%0Adef%20_(jnp%2C%20mesh%2C%20sp%2C%20t%2C%20u%2C%20x)%3A%0A%20%20%20%20from%20jaxfun.utils%20import%20lambdify%0A%0A%20%20%20%20c%20%3D%201.0%0A%20%20%20%20x0%20%3D%20mesh.get_points(1000%2C%20domain%3D%22all%22%2C%20kind%3D%22uniform%22)%0A%20%20%20%20ue%20%3D%20sp.sin(x)%20*%20(sp.sin(c%20*%20t)%20%2B%20sp.cos(c%20*%20t))%0A%20%20%20%20uej%20%3D%20lambdify(x%2C%20ue.subs(t%2C%200))(x0%5B%3A%2C%200%5D)%0A%20%20%20%20print(jnp.linalg.norm(uej%20-%20u(x0))%20%2F%20jnp.sqrt(uej.shape%5B0%5D))%0A%20%20%20%20return%20c%2C%20lambdify%2C%20ue%2C%20uej%2C%20x0%0A%0A%0A%40app.cell%0Adef%20_(u%2C%20uej%2C%20x0)%3A%0A%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%0A%20%20%20%20plt.plot(u(x0)%20-%20uej)%0A%20%20%20%20return%20(plt%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20solution%20is%20now%20converged%20for%20%24u%5E%7B0%7D_%7B%5Ctheta%7D%24%20and%20the%20last%20step%20%60trainer.step%60%20prepares%20the%20weights%20for%20the%20next%20time%20step.%20The%20trainer%20keeps%20track%20of%20the%20weights%20(the%20flax%20%60State%60)%20for%20each%20time%20step.%0A%0A%20%20%20%20We%20now%20proceed%20to%20compute%20%24u%5E%7B1%7D_%7B%5Ctheta%7D%24%20using%20%24%5Cmathcal%7BL%7D_%7BIC2%7D%24%20and%20the%20same%20trainer.%20All%20we%20need%20to%20do%20is%20to%20recompute%20the%20target%20of%20the%20second%20problem%20given%20to%20the%20%60Loss%60%20class.%20The%20target%20is%20now%20given%20by%20the%20right-hand%20side%20of%20the%20equation%20for%20%24u%5E%7B1%7D_%7B%5Ctheta%7D%24.%20We%20use%20the%20jaxfun%20class%20%60Residual%60%20to%20compute%20%24c%5E2%20%5CDelta%20t%5E2%20%5Cfrac%7B%5Cpartial%5E2%20u%5E%7B0%7D_%7B%5Ctheta%7D%7D%7B%5Cpartial%20x%5E2%7D%24%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(c%2C%20jnp%2C%20loss_fn%2C%20opt_lbfgs%2C%20trainer%2C%20u%2C%20x%2C%20xi)%3A%0A%20%20%20%20from%20jaxfun.pinns.loss%20import%20Residual%0A%0A%20%20%20%20dt%20%3D%20jnp.pi%20%2F%20100%0A%0A%20%20%20%20unm1_%20%3D%20u(xi)%0A%20%20%20%20target%20%3D%20Residual(c**2%20*%20dt**2%20*%20u.diff(x%2C%202)%2C%20xi)%0A%20%20%20%20loss_fn.residuals%5B1%5D.target%20%3D%20unm1_%20*%20(1%20%2B%20c%20*%20dt)%20%2B%200.5%20*%20target.evaluate(u.module)%0A%20%20%20%20trainer.train(%0A%20%20%20%20%20%20%20%20opt_lbfgs%2C%205000%2C%20epoch_print%3D1000%2C%20abs_limit_loss%3D1e-10%2C%20print_final_loss%3DTrue%0A%20%20%20%20)%0A%20%20%20%20trainer.step(u.module)%0A%20%20%20%20return%20dt%2C%20target%2C%20unm1_%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20now%20have%20the%20solution%20for%20%24u%5E0_%7B%5Ctheta%7D%24%20and%20%24u%5E1_%7B%5Ctheta%7D%24%2C%20and%20in%20python%20we%20store%20the%20evaluation%20of%20these%20solutions%20at%20the%20collocation%20points%20%60xi%60%20in%20%60unm1%60%20and%20%60un%60%2C%20respectively.%20Throughout%20%60unm1%60%20corresponds%20to%20%24u%5E%7Bn-1%7D_%7B%5Ctheta%7D(%5Cboldsymbol%7Bx%7D)%24%20and%20%60un%60%20corresponds%20to%20%24u%5E%7Bn%7D_%7B%5Ctheta%7D(%5Cboldsymbol%7Bx%7D)%24%2C%20where%20%24%5Cboldsymbol%7Bx%7D%3D(x_i)_%7Bi%3D0%7D%5E%7BM-3%7D%24.%0A%0A%20%20%20%20All%20that%20remains%20is%20to%20loop%20over%20the%20remaining%20time%20steps%20and%20solve%20for%20%24u%5E%7Bn%2B1%7D_%7B%5Ctheta%7D%24%20using%20the%20discretized%20wave%20equation%20residual%20loss%20%24%5Cmathcal%7BL%7D_%7BPDE%7D%5En%24%20and%20the%20boundary%20condition%20loss%20%24%5Cmathcal%7BL%7D_%7BBC%7D%5En%24.%20We%20implement%20the%20PDE%20loss%20by%20modifying%20the%20second%20argument%20of%20the%20%60Loss%60%20class%2C%20which%20computes%20the%20residual%20%24u%5E%7Bn%2B1%7D_%7B%5Ctheta%7D(x)%20-%20f(x)%24%2C%20where%20%24f(x)%24%20is%20a%20target.%20In%20the%20discrete%20PDE%20%24f(x)%20%3D%202%20u%5En_%7B%5Ctheta%7D%20-%20u%5E%7Bn-1%7D_%7B%5Ctheta%7D%20%2B%20c%5E2%20%5CDelta%20t%5E2%20%5Cfrac%7B%5Cpartial%5E2%20u%5E%7Bn%7D_%7B%5Ctheta%7D%7D%7B%5Cpartial%20x%5E2%7D%24.%20This%20term%20is%20evaluated%20once%20from%20known%20solutions%20and%20placed%20in%20the%20target%20of%20the%20loss%20function.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20dt%2C%0A%20%20%20%20jnp%2C%0A%20%20%20%20lambdify%2C%0A%20%20%20%20loss_fn%2C%0A%20%20%20%20mesh%2C%0A%20%20%20%20opt_lbfgs%2C%0A%20%20%20%20t%2C%0A%20%20%20%20target%2C%0A%20%20%20%20trainer%2C%0A%20%20%20%20u%2C%0A%20%20%20%20ue%2C%0A%20%20%20%20unm1_%2C%0A%20%20%20%20x%2C%0A%20%20%20%20xi%2C%0A)%3A%0A%20%20%20%20import%20optax%0A%0A%20%20%20%20unm1%20%3D%20unm1_%0A%20%20%20%20un%20%3D%20u(xi)%0A%20%20%20%20Nsteps%20%3D%2040%0A%20%20%20%20for%20_step%20in%20range(1%2C%20Nsteps)%3A%0A%20%20%20%20%20%20%20%20print(f%22Time%20step%20%7B_step%20%2B%201%7D%2F%7BNsteps%7D%22)%0A%20%20%20%20%20%20%20%20loss_fn.residuals%5B1%5D.target%20%3D%202%20*%20un%20-%20unm1%20%2B%20target.evaluate(u.module)%0A%20%20%20%20%20%20%20%20opt_lbfgs.opt_state%20%3D%20optax.tree.zeros_like(opt_lbfgs.opt_state)%0A%20%20%20%20%20%20%20%20trainer.train(%0A%20%20%20%20%20%20%20%20%20%20%20%20opt_lbfgs%2C%0A%20%20%20%20%20%20%20%20%20%20%20%202000%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20epoch_print%3D500%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20abs_limit_loss%3D1e-10%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20print_final_loss%3DTrue%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20unm1%20%3D%20un%0A%20%20%20%20%20%20%20%20un%20%3D%20u(xi)%0A%20%20%20%20%20%20%20%20trainer.step(u.module)%0A%20%20%20%20x1%20%3D%20mesh.get_points(1000%2C%20domain%3D%22all%22%2C%20kind%3D%22uniform%22)%0A%20%20%20%20uej_1%20%3D%20lambdify(x%2C%20ue.subs(t%2C%20dt%20*%20Nsteps))(x1%5B%3A%2C%200%5D)%0A%20%20%20%20print(%0A%20%20%20%20%20%20%20%20%22Error%20at%20final%20time%20step%3A%22%2C%0A%20%20%20%20%20%20%20%20jnp.linalg.norm(u(x1)%20-%20uej_1)%20%2F%20jnp.sqrt(uej_1.shape%5B0%5D)%2C%0A%20%20%20%20)%0A%20%20%20%20return%20(Nsteps%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Notice%20that%20each%20time%20step%20takes%20just%20a%20few%20iterations%20to%20converge%20since%20we%20start%20with%20a%20very%20good%20initial%20guess%20from%20the%20previous%20time%20step.%20We%20can%20now%20visualize%20the%20final%20solution%20by%20evaluating%20the%20neural%20networks%20at%20chosen%20time%20steps%20on%20a%20grid%20of%20points%20in%20space%20and%20time.%20We%20can%20then%20plot%20the%20solution%20as%20a%20surface%20plot.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(Nsteps%2C%20dt%2C%20lambdify%2C%20plt%2C%20t%2C%20trainer%2C%20u%2C%20ue%2C%20x%2C%20x0)%3A%0A%20%20%20%20results%20%3D%20%5B%5D%0A%20%20%20%20exact%20%3D%20%5B%5D%0A%20%20%20%20for%20_step%20in%20range(Nsteps)%3A%0A%20%20%20%20%20%20%20%20results.append(trainer.evaluate_at_step(u%2C%20x0%2C%20_step)%5B%3A%2C%200%5D)%0A%20%20%20%20%20%20%20%20exact.append(lambdify(x%2C%20ue.subs(t%2C%20_step%20*%20dt))(x0%5B%3A%2C%200%5D))%0A%20%20%20%20%20%20%20%20if%20_step%20%25%2020%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.plot(x0%2C%20results%5B-1%5D%2C%20label%3Df%22t%3D%7B_step%20*%20dt%3A.2f%7D%22)%0A%20%20%20%20plt.legend(loc%3D%22upper%20right%22)%0A%20%20%20%20plt.show()%0A%20%20%20%20return%20exact%2C%20results%0A%0A%0A%40app.cell%0Adef%20_(Nsteps%2C%20dt%2C%20jnp%2C%20plt%2C%20results)%3A%0A%20%20%20%20uj%20%3D%20jnp.array(results)%0A%20%20%20%20plt.contourf(uj.T%2C%20extent%3D%5B0%2C%20Nsteps%20*%20dt%2C%200%2C%20jnp.pi%5D%2C%20levels%3D100%2C%20cmap%3D%22jet%22)%0A%20%20%20%20return%20(uj%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20A%20heat%20map%20of%20the%20error%20between%20the%20PINN%20solution%20and%20the%20exact%20solution%20shows%20that%20the%20error%20is%20small%20throughout%20the%20domain%2C%20indicating%20that%20the%20PINN%20has%20successfully%20learned%20the%20solution%20to%20the%20discrete%20time%20wave%20equation%20and%20the%20error%20is%20well%20controlled.%20The%20time%20axis%20is%20along%20the%20vertical%20direction%20in%20the%20plots.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(Nsteps%2C%20dt%2C%20exact%2C%20jnp%2C%20plt%2C%20uj)%3A%0A%20%20%20%20plt.contourf(%0A%20%20%20%20%20%20%20%20(jnp.array(exact)%20-%20uj).T%2C%0A%20%20%20%20%20%20%20%20extent%3D%5B0%2C%20Nsteps%20*%20dt%2C%200%2C%20jnp.pi%5D%2C%0A%20%20%20%20%20%20%20%20levels%3D100%2C%0A%20%20%20%20%20%20%20%20cmap%3D%22jet%22%2C%0A%20%20%20%20)%0A%20%20%20%20plt.title(%22Absolute%20Error%22)%0A%20%20%20%20plt.colorbar()%0A%20%20%20%20plt.show()%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%0A%20%20%20%20return%20(mo%2C)%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
351669d0a3db153505bd04d4039f911a