import%20marimo%0A%0A__generated_with%20%3D%20%220.19.8%22%0Aapp%20%3D%20marimo.App()%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%20Physics%20informed%20neural%20networks%0A%0A%20%20%20%20In%20this%20notebook%20we%20will%20use%20%5BJaxfun%5D(https%3A%2F%2Fgithub.com%2FspectralDNS%2Fjaxfun)%20together%20with%0A%0A%20%20%20%20*%20%5BJax%5D(https%3A%2F%2Fdocs.jax.dev%2Fen%2Flatest%2Findex.html)%0A%20%20%20%20*%20%5BOptax%5D(https%3A%2F%2Foptax.readthedocs.io%2Fen%2Flatest%2F)%0A%20%20%20%20*%20%5BFlax%5D(https%3A%2F%2Fgithub.com%2Fgoogle%2Fflax)%0A%0A%20%20%20%20in%20order%20to%20solve%20a%20differential%20equation%20using%20the%20least%20squares%20minimization%20formulation%20of%20the%20problem.%20We%20will%20use%20basis%20functions%20based%20on%20regular%20multi%20layer%20perceptrons%20as%20well%20as%20spectral%20expansions%20in%20orthogonal%20polynomials.%0A%0A%20%20%20%20The%20ubiquitous%20linear%20Helmholtz%20equation%20is%20defined%20as%0A%0A%20%20%20%20%24%24%0A%20%20%20%20u%5E%7B''%7D(x)%20%2B%20%5Calpha%20u(x)%20%3D%20f(x)%2C%20%5Cquad%20x%20%5Cin%20(-1%2C%201)%2C%20%5C%2C%20%5Calpha%20%5Cin%20%5Cmathbb%7BR%5E%2B%7D%0A%20%20%20%20%24%24%0A%0A%20%20%20%20and%20in%20this%20notebook%20it%20will%20be%20used%20with%20boundary%20conditions%20%24u(-1)%3Du(1)%3D0%24.%20The%20function%20%24u(x)%24%20represents%20the%20unknown%20solution%20and%20the%20right%20hand%20side%20function%20%24f(x)%24%20is%20continuous%20and%20known.%20We%20can%20define%20a%20residual%20%24%5Cmathcal%7BR%7Du(x)%24%20for%20the%20Helmholtz%20equation%20as%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cmathcal%7BR%7Du(x)%20%3D%20u%5E%7B''%7D(x)%20%2B%20%5Calpha%20u(x)%20-%20f(x)%2C%0A%20%20%20%20%24%24%0A%0A%20%20%20%20which%20should%20be%20zero%20for%20any%20point%20in%20the%20domain.%20In%20this%20notebook%20we%20will%20approximate%20%24u(x)%24%20with%20a%20neural%20network%20%24u_%7B%5Ctheta%7D(x)%24%2C%20where%20%24%5Ctheta%20%5Cin%20%5Cmathbb%7BR%7D%5EM%24%20represents%20the%20unknown%20weights%20of%20the%20network.%20In%20order%20to%20find%20%24u_%7B%5Ctheta%7D(x)%24%20we%20will%20attempt%20to%20force%20the%20residual%20%24%5Cmathcal%7BR%7Du_%7B%5Ctheta%7D(x_i)%24%20to%20zero%20in%20a%20least%20squares%20sense%20for%20some%20%24N%24%20chosen%20training%20points%20%24%5C%7Bx_i%5C%7D_%7Bi%3D0%7D%5E%7BN-1%7D%24.%20To%20this%20end%20the%20least%20squares%20problem%20reads%0A%0A%20%20%20%20%5Cbegin%7Bequation*%7D%0A%20%20%20%20%5Cunderset%7B%5Ctheta%20%5Cin%20%5Cmathbb%7BR%7D%5EM%7D%7B%5Ctext%7Bminimize%7D%7D%5C%2C%20L(%5Ctheta)%3A%3D%5Cfrac%7B1%7D%7BN%7D%5Csum_%7Bi%3D0%7D%5E%7BN-1%7D%20%5Cmathcal%7BR%7Du_%7B%5Ctheta%7D(x_i%3B%20%5Ctheta)%5E2%20%2B%20%5Cfrac%7B1%7D%7B2%7D%20%5Cleft(u_%7B%5Ctheta%7D(-1%3B%20%5Ctheta)%5E2%20%2B%20u_%7B%5Ctheta%7D(1%3B%20%5Ctheta)%5E2%20%5Cright)%0A%20%20%20%20%5Cend%7Bequation*%7D%0A%0A%20%20%20%20We%20start%20by%20importing%20necessary%20functionality%20from%20both%20%5Bjax%5D(https%3A%2F%2Fdocs.jax.dev%2Fen%2Flatest%2Findex.html)%2C%20%5Boptax%5D(https%3A%2F%2Foptax.readthedocs.io%2Fen%2Flatest%2F)%2C%20%5Bflax%5D(https%3A%2F%2Fgithub.com%2Fgoogle%2Fflax)%20and%20Jaxfun.%20We%20also%20make%20use%20of%20Sympy%20in%20order%20to%20describe%20the%20equations.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20jax%0A%0A%20%20%20%20jax.config.update(%22jax_enable_x64%22%2C%20True)%0A%0A%20%20%20%20import%20jax.numpy%20as%20jnp%0A%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%20%20%20%20import%20sympy%20as%20sp%0A%20%20%20%20from%20flax%20import%20nnx%0A%0A%20%20%20%20from%20jaxfun.operators%20import%20Div%2C%20Grad%0A%20%20%20%20from%20jaxfun.pinns%20import%20FlaxFunction%2C%20Loss%2C%20MLPSpace%0A%20%20%20%20from%20jaxfun.pinns.optimizer%20import%20GaussNewton%2C%20Trainer%2C%20adam%2C%20lbfgs%0A%20%20%20%20from%20jaxfun.utils.common%20import%20lambdify%2C%20ulp%0A%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20Div%2C%0A%20%20%20%20%20%20%20%20FlaxFunction%2C%0A%20%20%20%20%20%20%20%20GaussNewton%2C%0A%20%20%20%20%20%20%20%20Grad%2C%0A%20%20%20%20%20%20%20%20Loss%2C%0A%20%20%20%20%20%20%20%20MLPSpace%2C%0A%20%20%20%20%20%20%20%20Trainer%2C%0A%20%20%20%20%20%20%20%20adam%2C%0A%20%20%20%20%20%20%20%20jax%2C%0A%20%20%20%20%20%20%20%20jnp%2C%0A%20%20%20%20%20%20%20%20lambdify%2C%0A%20%20%20%20%20%20%20%20lbfgs%2C%0A%20%20%20%20%20%20%20%20nnx%2C%0A%20%20%20%20%20%20%20%20plt%2C%0A%20%20%20%20%20%20%20%20sp%2C%0A%20%20%20%20%20%20%20%20ulp%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Notice%20the%20%60MLPSpace%60%20class%2C%20which%20represents%20a%20functionspace%20for%20a%20regular%20multilayer%20perceptron.%20The%20space%20will%20make%20use%20of%20a%20subclass%20of%20the%20%5Bflax%5D(https%3A%2F%2Fflax.readthedocs.io%2Fen%2Flatest)%20%5Bnnx.Module%5D(https%3A%2F%2Fflax.readthedocs.io%2Fen%2Flatest%2Fapi_reference%2Fflax.nnx%2Fmodule.html%23module-flax.nnx).%0A%0A%20%20%20%20This%20space%20holds%20information%20about%20the%20input%2C%20output%20and%20hidden%20layers%20in%20the%20neural%20network.%20Here%20we%20create%20an%20MLP%20for%20a%20one-dimensional%20problem%20(one%20input%20variable)%20and%2016%20neurons%20for%20a%20single%20hidden%20layer.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(MLPSpace)%3A%0A%20%20%20%20V%20%3D%20MLPSpace(%5B16%5D%2C%20dims%3D1%2C%20name%3D%22V%22)%0A%20%20%20%20return%20(V%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20It%20is%20possible%20to%20use%20several%20hidden%20layers%2C%20for%20example%20by%20choosing%20%60V%20%3D%20MLPSpace(%5B8%2C%208%2C%208%5D%2C%20dims%3D1)%60.%0A%0A%20%20%20%20The%20MLP%20function%20space%20is%20subsequently%20used%20to%20create%20a%20trial%20function%20for%20the%20neural%20network.%20The%20function%20%60v%60%20below%20holds%20all%20the%20unknown%20weights%20in%20the%20MLP.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(FlaxFunction%2C%20V%2C%20nnx)%3A%0A%20%20%20%20v%20%3D%20FlaxFunction(%0A%20%20%20%20%20%20%20%20V%2C%0A%20%20%20%20%20%20%20%20rngs%3Dnnx.Rngs(1001)%2C%0A%20%20%20%20%20%20%20%20name%3D%22v%22%2C%0A%20%20%20%20%20%20%20%20fun_str%3D%22phi%22%2C%0A%20%20%20%20%20%20%20%20kernel_init%3Dnnx.initializers.xavier_uniform()%2C%0A%20%20%20%20)%0A%20%20%20%20return%20(v%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Inside%20%60v%60%2C%20the%20regular%20flax%20nnx%20module%20is%20accessible%20through%20%60v.module%60.%0A%0A%20%20%20%20We%20will%20test%20the%20solver%20using%20a%20known%20manufactured%20solution.%20We%20can%20use%20any%20solution%2C%20but%20it%20should%20be%20continuous%20and%20the%20solution%20needs%20to%20use%20the%20same%20symbols%20as%20Jaxfun.%20Below%20we%20choose%20a%20mixture%20of%20a%20second%20order%20polynomial%20(to%20get%20the%20correct%20boundary%20condition)%20an%20exponential%20and%20a%20cosine%20function.%20This%20function%20is%20continuous%2C%20but%20it%20requires%20quite%20a%20few%20unknowns%20in%20order%20to%20get%20a%20decent%20solution.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(V%2C%20sp)%3A%0A%20%20%20%20x%20%3D%20V.system.x%0A%20%20%20%20ue%20%3D%20(1%20-%20x**2)%20*%20sp.exp(sp.cos(2%20*%20sp.pi%20*%20x))%0A%20%20%20%20return%20ue%2C%20x%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20equation%20to%20solve%20is%20now%20described%20in%20strong%20form%20using%20the%20residual%20%24%5Cmathcal%7BR%7Du_%7B%5Ctheta%7D%24.%20Note%20that%20we%20create%20the%20right%20hand%20side%20function%20%24f(x)%24%20from%20the%20manufactured%20solution.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(Div%2C%20Grad%2C%20ue%2C%20v)%3A%0A%20%20%20%20alpha%20%3D%201%0A%20%20%20%20fe%20%3D%20Div(Grad(ue))%20%2B%20alpha%20*%20ue%0A%20%20%20%20residual%20%3D%20Div(Grad(v))%20%2B%20alpha%20*%20v%20-%20fe%0A%20%20%20%20return%20alpha%2C%20residual%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20two%20operators%20%5BDiv%5D(https%3A%2F%2Fgithub.com%2FspectralDNS%2Fjaxfun%2Fblob%2Fmain%2Fjaxfun%2Foperators.py)%20and%20%5BGrad%5D(https%3A%2F%2Fgithub.com%2FspectralDNS%2Fjaxfun%2Fblob%2Fmain%2Fjaxfun%2Foperators.py)%20are%20defined%20in%20Jaxfun.%20For%20a%201D%20problem%20on%20the%20straight%20line%20there%20is%20no%20difference%20from%20writing%20the%20residual%20simply%20as%0A%0A%20%20%20%20residual%20%3D%20sp.diff(v%2C%20x%2C%202)%20%2B%20alpha%20*%20v%20-%20(sp.diff(ue%2C%20x%2C%202)%20%2B%20alpha*ue)%0A%0A%20%20%20%20We%20can%20look%20at%20the%20residual%20in%20code%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(residual)%3A%0A%20%20%20%20residual%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Note%20that%20%24%5Cnabla%20%5Ccdot%24%20represents%20divergence%20and%20%24%5Cnabla%20v%24%20represents%20the%20gradient%20of%20the%20scalar%20field%20%24v%24.%20The%20neural%20network%20function%20%24v%24%20is%20written%20as%20%24v(x%3B%20V)%24%20since%20%24v%24%20is%20a%20function%20of%20%24x%24%20and%20it%20is%20a%20function%20on%20the%20space%20%24V%24.%20The%20residual%20above%20is%20in%20unevaluated%20form.%20We%20can%20evaluate%20it%20using%20the%20Sympy%20function%20%60doit%60%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(residual)%3A%0A%20%20%20%20residual.doit()%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Note%20that%20Sympy%20here%20evaluates%20derivatives%20to%20the%20best%20of%20its%20abilities%2C%20and%20the%20neural%20network%20function%20%24v%24%20has%20been%20replaced%20by%20the%20expression%20%24%5Cphi(x)%24.%20This%20is%20because%20%60phi%60%20is%20set%20as%20%60fun_str%60%20for%20the%20created%20%60FlaxFunction%60%20%60v%60.%0A%0A%20%20%20%20We%20need%20training%20points%20inside%20the%20domain%20in%20order%20to%20solve%20the%20least%20squares%20problem.%20Create%20random%20points%20%60xj%60%20using%20the%20helper%20class%20%60Line%60%20and%20an%20array%20%60xb%60%20that%20holds%20the%20coordinates%20of%20the%20two%20boundaries.%20Note%20that%20the%20argument%20to%20%60mesh.get_points%60%20is%20the%20total%20number%20of%20points%20in%20the%20mesh%2C%20including%20boundary%20points.%20Since%20it's%20a%201D%20mesh%2C%20the%20number%20of%20boundary%20points%20is%202.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(jax)%3A%0A%20%20%20%20from%20jaxfun.pinns.mesh%20import%20Line%0A%0A%20%20%20%20mesh%20%3D%20Line(-1%2C%201%2C%20key%3Djax.random.PRNGKey(2002))%0A%20%20%20%20xj%20%3D%20mesh.get_points(1200%2C%20domain%3D%22inside%22%2C%20kind%3D%22random%22)%0A%20%20%20%20xb%20%3D%20mesh.get_points(1200%2C%20domain%3D%22boundary%22)%0A%20%20%20%20return%20xb%2C%20xj%0A%0A%0A%40app.cell%0Adef%20_(xb)%3A%0A%20%20%20%20print(xb)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20have%20two%20coupled%20problems%20to%20solve%3A%20The%20equation%20defined%20by%20%60residual%60%20and%20the%20boundary%20conditions.%20In%20order%20to%20solve%20these%20problems%20we%20now%20make%20use%20of%20the%20%60Loss%60%20class%20and%20a%20feed%20the%20two%20problems%20to%20it.%20We%20also%20need%20to%20feed%20the%20correct%20collocation%20points%20to%20each%20problem%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(Loss%2C%20residual%2C%20v%2C%20xb%2C%20xj)%3A%0A%20%20%20%20loss_fn%20%3D%20Loss((residual%2C%20xj)%2C%20(v%2C%20xb))%0A%20%20%20%20loss_fn.residuals%0A%20%20%20%20return%20(loss_fn%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20first%20residual%20represents%20the%20equation%20to%20solve.%20It%20contains%20two%20subequations%20representing%20%24v''%24%20and%20%24%5Calpha%20v%24%2C%20wheras%20the%20constant%20part%20of%20the%20equation%20is%20placed%20in%20the%20target%20(an%20array%20of%20shape%20%24N-2%24).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn)%3A%0A%20%20%20%20loss_fn.residuals%5B0%5D.eqs%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn)%3A%0A%20%20%20%20print(loss_fn.residuals%5B0%5D.target.shape)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20As%20such%20each%20%60Residual%60%20is%20a%20class%20holding%20the%20required%20functions%20in%20order%20to%20compute%20the%20residuals.%20The%20first%20one%20computes%20%24%5Cmathcal%7BR%7Du_%7B%5Ctheta%7D(x)%24%20and%20the%20other%20the%20boundary%20terms%20%24u_%7B%5Ctheta%7D(-1)%2Bu_%7B%5Ctheta%7D(1)%24.%20Calling%20%60loss_fn(v.module)%60%20returns%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cfrac%7B1%7D%7BN-2%7D%5Csum_%7Bi%3D0%7D%5E%7BN-3%7D%20%5Cmathcal%7BR%7Du_%7B%5Ctheta%7D(x_i%3B%20%5Ctheta)%5E2%20%2B%20%5Cfrac%7B1%7D%7B2%7D%20%5Cleft(u_%7B%5Ctheta%7D(-1%3B%20%5Ctheta)%5E2%20%2B%20u_%7B%5Ctheta%7D(1%3B%20%5Ctheta)%5E2%20%5Cright)%0A%20%20%20%20%24%24%0A%0A%20%20%20%20In%20order%20to%20solve%20the%20least%20squares%20problem%20we%20need%20an%20optimizer.%20Any%20%60optax%60%20optimizer%20may%20be%20used%2C%20but%20we%20will%20start%20with%20Adam%2C%20and%20then%20switch%20to%20a%20more%20accurate%20optimizer%20after%20a%20while.%20We%20first%20run%205000%20epochs%20with%20Adam%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(Trainer%2C%20adam%2C%20loss_fn%2C%20v)%3A%0A%20%20%20%20trainer%20%3D%20Trainer(loss_fn)%0A%20%20%20%20opt_adam%20%3D%20adam(v.module%2C%20learning_rate%3D0.001)%0A%20%20%20%20trainer.train(opt_adam%2C%205000%2C%20epoch_print%3D1000)%0A%20%20%20%20return%20(trainer%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20Adam%20optimizer%20is%20good%20at%20eliminating%20local%20minima%20and%20as%20such%20it%20is%20good%20at%20finding%20a%20solution%20that%20is%20close%20to%20the%20global%20minimum.%20However%2C%20Adam%20is%20only%20first%20order%20and%20not%20able%20to%20find%20a%20very%20accurate%20solution.%20For%20this%20we%20need%20either%20a%20quasi-Newton%20or%20a%20Newton%20optimizer.%20We%20start%20with%20the%20limited-memory%20BFGS%20optimizer%20and%20take%2010000%20more%20epochs.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(lbfgs%2C%20trainer%2C%20v)%3A%0A%20%20%20%20opt_lbfgs%20%3D%20lbfgs(v.module%2C%20memory_size%3D20)%0A%20%20%20%20trainer.train(opt_lbfgs%2C%2010000%2C%20epoch_print%3D1000)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20It%20is%20possible%20to%20run%20even%20more%20BFGS%20epoch%20to%20further%20polish%20this%20root.%20However%2C%20we%20will%20in%20the%20end%20switch%20to%20an%20even%20more%20accurate%20Newton%20optimizer.%20Since%20the%20Newton%20optimizer%20is%20costly%2C%20we%20run%20only%2010%20epochs.%20The%20Newton%20optimizer%20should%20only%20be%20used%20when%20the%20residual%20is%20already%20close%20to%20zero.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GaussNewton%2C%20trainer%2C%20v)%3A%0A%20%20%20%20opt_hess%20%3D%20GaussNewton(v.module%2C%20use_lstsq%3DFalse%2C%20cg_max_iter%3D500)%0A%20%20%20%20trainer.train(opt_hess%2C%2010%2C%20epoch_print%3D1)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Running%20even%20more%20epochs%20the%20solution%20will%20become%20even%20more%20accurate.%0A%0A%20%20%20%20We%20can%20now%20compute%20the%20%24L%5E2%24%20error%20norm%20by%20comparing%20to%20the%20exact%20solution%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(jnp%2C%20lambdify%2C%20ue%2C%20v%2C%20x%2C%20xj)%3A%0A%20%20%20%20uej%20%3D%20lambdify(x%2C%20ue)(xj)%20%20%23%20Exact%0A%20%20%20%20print(%22Error%22%2C%20jnp.linalg.norm(v.module(xj)%20-%20uej)%20%2F%20jnp.sqrt(len(xj)))%0A%20%20%20%20return%20(uej%2C)%0A%0A%0A%40app.cell%0Adef%20_(jnp%2C%20lambdify%2C%20plt%2C%20ue%2C%20v%2C%20x)%3A%0A%20%20%20%20xa%20%3D%20jnp.linspace(-1%2C%201%2C%20100)%5B%3A%2C%20None%5D%0A%20%20%20%20uea%20%3D%20lambdify(x%2C%20ue)(xa)%0A%20%20%20%20fig%2C%20(ax1%2C%20ax2)%20%3D%20plt.subplots(1%2C%202%2C%20figsize%3D(10%2C%204))%0A%20%20%20%20ax1.plot(xa%2C%20v.module(xa)%2C%20%22b%22%2C%20label%3D%22PINNs%22)%0A%20%20%20%20ax1.plot(xa%2C%20uea%2C%20%22ro%22%2C%20label%3D%22Exact%22)%0A%20%20%20%20ax1.legend()%0A%20%20%20%20ax2.plot(xa%2C%20v.module(xa)%20-%20uea%2C%20%22b*%22)%0A%20%20%20%20ax1.set_title(%22Solutions%22)%0A%20%20%20%20ax2.set_title(%22Error%22)%0A%20%20%20%20plt.show()%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20residuals%20of%20each%20problem%20may%20also%20be%20computed%20using%20the%20%60loss_fn%60%20class.%20The%20residuals%20for%20the%20two%20boundary%20conditions%20are%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn%2C%20v)%3A%0A%20%20%20%20loss_fn.compute_residual_i(v.module%2C%201)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%20Spectral%20least%20squares%20solver%0A%0A%20%20%20%20The%20neural%20network%20is%20capturing%20the%20solution%20quite%20well%2C%20but%20the%20convergence%20is%20quite%20slow.%20We%20know%20that%20a%20problem%20like%20the%20Helmholtz%20equation%20with%20a%20continuous%20solution%20should%20be%20very%20well%20captured%20using%20Legendre%20or%20Chebyshev%20basis%20functions%2C%20that%20have%20much%20better%20approximation%20properties%20than%20the%20neural%20network.%20Using%20Jaxfun%20we%20can%20solve%20this%20problem%20with%20the%20Galerkin%20method%2C%20but%20we%20can%20also%20use%20the%20least%20squares%20formulation%20similar%20to%20as%20above.%0A%0A%20%20%20%20The%20solver%20below%20is%20using%20simply%0A%0A%20%20%20%20%24%24%0A%20%20%20%20u_%7B%5Ctheta%7D(x)%20%3D%20%5Csum_%7Bi%3D0%7D%5E%7BN-1%7D%20%5Chat%7Bu%7D_i%20L_i(x)%0A%20%20%20%20%24%24%0A%0A%20%20%20%20where%20%24L_i(x)%24%20is%20the%20i'th%20Legendre%20polynomial%20and%20%24%5Chat%7Bu%7D_i%24%20are%20the%20unknowns.%0A%0A%20%20%20%20The%20least%20squares%20implementation%20goes%20as%20follows%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20Div%2C%0A%20%20%20%20FlaxFunction%2C%0A%20%20%20%20GaussNewton%2C%0A%20%20%20%20Grad%2C%0A%20%20%20%20Loss%2C%0A%20%20%20%20Trainer%2C%0A%20%20%20%20adam%2C%0A%20%20%20%20alpha%2C%0A%20%20%20%20lbfgs%2C%0A%20%20%20%20nnx%2C%0A%20%20%20%20ue%2C%0A%20%20%20%20ulp%2C%0A%20%20%20%20xb%2C%0A%20%20%20%20xj%2C%0A)%3A%0A%20%20%20%20from%20jaxfun.galerkin.Legendre%20import%20Legendre%0A%0A%20%20%20%20VN%20%3D%20Legendre(60)%0A%20%20%20%20w%20%3D%20FlaxFunction(%0A%20%20%20%20%20%20%20%20VN%2C%20rngs%3Dnnx.Rngs(1001)%2C%20kernel_init%3Dnnx.initializers.xavier_uniform()%2C%20name%3D%22v%22%0A%20%20%20%20)%0A%20%20%20%20res%20%3D%20Div(Grad(w))%20%2B%20alpha%20*%20w%20-%20(Div(Grad(ue))%20%2B%20alpha%20*%20ue)%0A%20%20%20%20loss_fn_1%20%3D%20Loss((res%2C%20xj)%2C%20(w%2C%20xb))%0A%20%20%20%20trainer_1%20%3D%20Trainer(loss_fn_1)%0A%20%20%20%20opt_adam_1%20%3D%20adam(w.module)%0A%20%20%20%20trainer_1.train(opt_adam_1%2C%205000%2C%20epoch_print%3D1000)%0A%20%20%20%20opt_lbfgs_1%20%3D%20lbfgs(w.module)%0A%20%20%20%20trainer_1.train(opt_lbfgs_1%2C%201000%2C%20epoch_print%3D100)%0A%20%20%20%20opt_hess_1%20%3D%20GaussNewton(w.module%2C%20use_lstsq%3DTrue%2C%20cg_max_iter%3D100)%0A%20%20%20%20trainer_1.train(opt_hess_1%2C%204%2C%20epoch_print%3D1%2C%20abs_limit_loss%3Dulp(1))%0A%20%20%20%20return%20Legendre%2C%20w%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Note%20that%20with%20only%202%20Newton%20iterations%20the%20error%20plunges%20to%20zero.%20This%20is%20because%20this%20method%20is%20an%20exact%20solver%20for%20linear%20equations%20and%20expansions%20like%20the%20Legendre%20polynomials.%20Comparing%20now%20with%20the%20exact%20solution%20we%20get%20a%20very%20accurate%20%24L%5E2%24%20error%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(jnp%2C%20uej%2C%20w%2C%20xj)%3A%0A%20%20%20%20print(%22Error%22%2C%20jnp.linalg.norm(w.module(xj)%20-%20uej)%20%2F%20jnp.sqrt(len(xj)))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(plt%2C%20uej%2C%20w%2C%20xj)%3A%0A%20%20%20%20plt.plot(xj%2C%20w.module(xj)%20-%20uej%2C%20%22b*%22)%0A%20%20%20%20plt.title(%22Error%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%20Implicit%20spectral%20least%20squares%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(Div%2C%20Grad%2C%20Legendre%2C%20alpha%2C%20jnp%2C%20lambdify%2C%20ue%2C%20x%2C%20xj)%3A%0A%20%20%20%20from%20jaxfun.galerkin%20import%20FunctionSpace%2C%20TestFunction%2C%20TrialFunction%2C%20inner%0A%0A%20%20%20%20D%20%3D%20FunctionSpace(%0A%20%20%20%20%20%20%20%2060%2C%20Legendre%2C%20bcs%3D%7B%22left%22%3A%20%7B%22D%22%3A%200%7D%2C%20%22right%22%3A%20%7B%22D%22%3A%200%7D%7D%2C%20name%3D%22D%22%2C%20fun_str%3D%22psi%22%0A%20%20%20%20)%0A%20%20%20%20q%20%3D%20TestFunction(D%2C%20name%3D%22v%22)%0A%20%20%20%20u%20%3D%20TrialFunction(D%2C%20name%3D%22u%22)%0A%20%20%20%20A%2C%20L%20%3D%20inner(%0A%20%20%20%20%20%20%20%20(Div(Grad(q))%20%2B%20alpha%20*%20q)%20*%20(Div(Grad(u))%20%2B%20alpha%20*%20u)%0A%20%20%20%20%20%20%20%20-%20(Div(Grad(q))%20%2B%20alpha%20*%20q)%20*%20(Div(Grad(ue))%20%2B%20alpha%20*%20ue)%2C%0A%20%20%20%20%20%20%20%20sparse%3DTrue%2C%0A%20%20%20%20%20%20%20%20sparse_tol%3D1000%2C%0A%20%20%20%20%20%20%20%20return_all_items%3DFalse%2C%0A%20%20%20%20)%0A%20%20%20%20uh%20%3D%20jnp.linalg.solve(A.todense()%2C%20L)%0A%20%20%20%20uj%20%3D%20D.evaluate(xj%2C%20uh)%0A%20%20%20%20uej_1%20%3D%20lambdify(x%2C%20ue)(xj)%0A%20%20%20%20error%20%3D%20jnp.linalg.norm(uj%20-%20uej_1)%20%2F%20jnp.sqrt(len(xj))%0A%20%20%20%20print(error)%0A%20%20%20%20return%20uej_1%2C%20uj%0A%0A%0A%40app.cell%0Adef%20_(plt%2C%20uej_1%2C%20uj%2C%20xj)%3A%0A%20%20%20%20plt.plot(xj%2C%20uj%20-%20uej_1%2C%20%22b*%22)%0A%20%20%20%20plt.title(%22Error%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20solution%20with%20implicit%20Galerkin%20or%20least%20squares%20is%20about%20as%20accurate%20as%20the%20regular%20least%20squares.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%0A%20%20%20%20return%20(mo%2C)%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
6e8108969ddbc0f8e1d47b80dfd8c856