import%20marimo%0A%0A__generated_with%20%3D%20%220.19.8%22%0Aapp%20%3D%20marimo.App()%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%20Lid%20driven%20cavity%0A%0A%20%20%20%20In%20this%20notebook%20we%20use%20Physics%20Informed%20Neural%20Networks%20(PINNS)%20to%20solve%20the%20Navier-Stokes%20equations%20in%20a%20rectangular%20domain.%20The%20Navier-Stokes%20problem%20to%20solve%20is%0A%0A%20%20%20%20%5Cbegin%7Balign*%7D%0A%20%20%20%20(%5Cboldsymbol%7Bu%7D%20%5Ccdot%20%5Cnabla)%20%5Cboldsymbol%7Bu%7D%20-%20%5Cnu%20%5Cnabla%5E2%20%5Cboldsymbol%7Bu%7D%20%2B%20%5Cnabla%20p%20%26%3D%200%2C%20%5Cquad%20%5Cboldsymbol%7Bx%7D%20%5Cin%20%5COmega%20%3D%20(-1%2C%201)%5E2%5C%5C%0A%20%20%20%20%5Cnabla%20%5Ccdot%20%5Cboldsymbol%7Bu%7D%20%26%3D%200%2C%20%5Cquad%20%5Cboldsymbol%7Bx%7D%20%5Cin%20%5COmega%20%3D%20(-1%2C%201)%5E2%20%5C%5C%0A%20%20%20%20%5Cend%7Balign*%7D%0A%0A%20%20%20%20where%20%24%5Cnu%24%20is%20a%20constant%20kinematic%20viscosity%2C%20%24p(%5Cboldsymbol%7Bx%7D)%24%20is%20pressure%2C%20%24%5Cboldsymbol%7Bu%7D(%5Cboldsymbol%7Bx%7D)%20%3D%20u_x(%5Cboldsymbol%7Bx%7D)%20%5Cboldsymbol%7Bi%7D%20%2B%20u_y(%5Cboldsymbol%7Bx%7D)%20%5Cboldsymbol%7Bj%7D%24%20is%20the%20velocity%20vector%20and%20the%20position%20vector%20%24%5Cboldsymbol%7Bx%7D%20%3D%20x%20%5Cboldsymbol%7Bi%7D%20%2B%20y%20%5Cboldsymbol%7Bj%7D%24%2C%20with%20unit%20vectors%20%24%5Cboldsymbol%7Bi%7D%24%20and%20%24%5Cboldsymbol%7Bj%7D%24.%20The%20Dirichlet%20boundary%20condition%20for%20the%20velocity%20vector%20is%20zero%20everywhere%2C%20except%20for%20the%20top%20lid%2C%20where%20%24%5Cboldsymbol%7Bu%7D(x%2C%20y%3D1)%20%3D%20(1-x)%5E2(1%2Bx)%5E2%20%5Cboldsymbol%7Bi%7D%24.%20There%20is%20no%20boundary%20condition%20on%20pressure%2C%20but%20since%20the%20pressure%20is%20only%20present%20inside%20a%20gradient%2C%20it%20can%20only%20be%20found%20up%20to%20a%20constant.%20Hence%2C%20we%20specify%20that%20%24p(x%3D0%2C%20y%3D0)%20%3D%200%24.%0A%0A%20%20%20%20We%20start%20the%20implementation%20by%20importing%20necessary%20modules%20such%20as%20%5Bjax%5D(https%3A%2F%2Fdocs.jax.dev%2Fen%2Flatest%2Findex.html)%2C%20%5Bflax%5D(https%3A%2F%2Fflax.readthedocs.io%2F)%2C%20where%20the%20latter%20is%20a%20module%20that%20provides%20neural%20networks%20for%20JAX.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20jax%0A%20%20%20%20import%20jax.numpy%20as%20jnp%0A%20%20%20%20from%20flax%20import%20nnx%0A%0A%20%20%20%20jax.config.update(%22jax_enable_x64%22%2C%20True)%0A%20%20%20%20return%20jnp%2C%20nnx%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20will%20solve%20the%20Navier-Stokes%20equations%20using%20a%20multilevel%20perceptron%20neural%20network%2C%20where%20the%20solution%20will%20be%20approximated%20as%0A%0A%20%20%20%20%24%24%0A%20%20%20%20F_%7B%5Ctheta%7D(%5Cboldsymbol%7Bx%7D%3B%20%5Ctheta)%20%3D%20%20W%5E%7BL%7D%20%5Csigma(%20W%5E%7BL-1%7D%20%5Cldots%20%5Csigma(%20W%5E1%20%5Cboldsymbol%7Bx%7D%20%2B%20%5Cboldsymbol%7Bb%7D%5E1)%20%5Cldots%20%2B%20%5Cboldsymbol%7Bb%7D%5E%7BL-1%7D)%20%20%2B%20%5Cboldsymbol%7Bb%7D%5EL%0A%20%20%20%20%24%24%0A%0A%20%20%20%20where%20%24%5Ctheta%20%3D%20%5C%7BW%5El%2C%20%5Cboldsymbol%7Bb%7D%5El%5C%7D_%7Bl%3D1%7D%5EL%24%20represents%20all%20the%20unknowns%20in%20the%20model%20and%20%24W%5El%2C%20%5Cboldsymbol%7Bb%7D%5El%24%20represents%20the%20weights%20and%20biases%20on%20level%20%24l%24.%20The%20model%20contains%20both%20velocity%20components%20and%20the%20pressure%2C%20for%20a%20total%20of%20three%20scalar%20outputs%3A%20%24u_%7Bx%7D(%5Cboldsymbol%7Bx%7D)%2C%20u_y(%5Cboldsymbol%7Bx%7D)%24%20and%20%24p(%5Cboldsymbol%7Bx%7D)%24.%20Hence%20%24F_%7B%5Ctheta%7D%3A%20%5Cmathbb%7BR%7D%5E2%20%5Crightarrow%20%5Cmathbb%7BR%7D%5E3%24.%0A%0A%20%20%20%20We%20split%20the%20coupled%20%24F_%7B%5Ctheta%7D%24%20into%20velocity%20and%20pressure%0A%0A%20%20%20%20%24%24%0A%20%20%20%20F_%7B%5Ctheta%7D%20%3D%20%5Cboldsymbol%7Bu%7D_%7B%5Ctheta%7D%20%5Ctimes%20p_%7B%5Ctheta%7D%0A%20%20%20%20%24%24%0A%0A%20%20%20%20where%20%24%5Cboldsymbol%7Bu%7D_%7B%5Ctheta%7D%3A%20%5Cmathbb%7BR%7D%5E2%20%5Crightarrow%20%5Cmathbb%7BR%7D%5E2%24%20is%20a%20vector%20function%20and%20%24p_%7B%5Ctheta%7D%3A%20%5Cmathbb%7BR%7D%5E2%20%5Crightarrow%20%5Cmathbb%7BR%7D%24%20is%20a%20scalar%20function.%0A%0A%20%20%20%20We%20start%20the%20implementation%20by%20creating%20multilayer%20perceptron%20functionspaces%20for%20%24%5Cboldsymbol%7Bu%7D_%7B%5Ctheta%7D%24%20and%20%24p_%7B%5Ctheta%7D%24%2C%20and%20then%20combining%20these%20to%20a%20space%20for%20%24F_%7B%5Ctheta%7D%24%20using%20a%20Cartesian%20product%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(nnx)%3A%0A%20%20%20%20from%20jaxfun.pinns%20import%20FlaxFunction%2C%20MLPSpace%2C%20Comp%0A%0A%20%20%20%20V%20%3D%20MLPSpace(%5B16%5D%2C%20dims%3D2%2C%20rank%3D1%2C%20name%3D%22V%22)%20%20%23%20Vector%20space%20for%20velocity%0A%20%20%20%20Q%20%3D%20MLPSpace(%5B12%5D%2C%20dims%3D2%2C%20rank%3D0%2C%20name%3D%22Q%22)%20%20%23%20Scalar%20space%20for%20pressure%0A%0A%20%20%20%20u%20%3D%20FlaxFunction(V%2C%20%22u%22%2C%20rngs%3Dnnx.Rngs(2002))%0A%20%20%20%20p%20%3D%20FlaxFunction(Q%2C%20%22p%22%2C%20rngs%3Dnnx.Rngs(2002))%0A%20%20%20%20module%20%3D%20Comp(u%2C%20p)%0A%20%20%20%20return%20V%2C%20module%2C%20p%2C%20u%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Note%20that%20%60u%60%20and%20%60p%60%20both%20are%20%60FlaxFunction%60s%2C%20that%20are%20subclasses%20of%20the%20Sympy%20%5BFunction%5D(https%3A%2F%2Fdocs.sympy.org%2Flatest%2Fmodules%2Ffunctions%2Findex.html).%20However%2C%20in%20Jaxfun%20these%20functions%20have%20some%20additional%20properties%2C%20that%20makes%20it%20easy%20to%20describe%20equations.%0A%0A%20%20%20%20We%20can%20inspect%20%60u%60%20and%20%60p%60%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(p%2C%20u)%3A%0A%20%20%20%20from%20IPython.display%20import%20display%0A%0A%20%20%20%20display(u)%0A%20%20%20%20display(u.doit())%0A%20%20%20%20display(p)%0A%20%20%20%20display(p.doit())%0A%20%20%20%20return%20(display%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Note%20that%20%60u%60%20and%20%60p%60%20are%20in%20unevaluated%20state%2C%20whereas%20%60u.doit()%60%20and%20%60p.doit()%60%20returns%20sympy%20functions%20for%20the%20computational%20space.%20If%20we%20check%20the%20type%20of%20%60u.doit()%60%2C%20we%20get%20that%20it%20is%20a%20%60VectorAdd%60%2C%20because%20the%20vector%20is%20an%20addition%20of%20the%20two%20vector%20components%20in%20%24(u_x(x%2C%20y))%5Cboldsymbol%7Bi%7D%20%2B%20(u_y(x%2C%20y))%20%5Cboldsymbol%7Bj%7D%24.%20The%20three%20terms%20%24u_x(x%2C%20y)%2C%20u_y(x%2C%20y)%24%20and%20%24p(x%2C%20y)%24%20are%20all%20sympy%20functions%20and%20of%20type%20%60AppliedUndef%60.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20now%20describe%20the%20Navier-Stokes%20equations%20using%20%60Div%60%2C%20%60Grad%60%20and%20%60Dot%60%20from%20Jaxfun.%20First%20specify%20the%20Reynolds%20number%2C%20the%20kinematic%20viscosity%20and%20then%20the%20equations%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(p%2C%20u)%3A%0A%20%20%20%20import%20sympy%20as%20sp%0A%20%20%20%20from%20jaxfun.operators%20import%20Div%2C%20Dot%2C%20Grad%2C%20Constant%0A%0A%20%20%20%20Re%20%3D%2010.0%20%20%23%20Define%20Reynolds%20number%0A%20%20%20%20nu%20%3D%20Constant(%0A%20%20%20%20%20%20%20%20%22nu%22%2C%202.0%20%2F%20Re%0A%20%20%20%20)%20%20%23%20Define%20kinematic%20viscosity.%20A%20number%20works%20as%20well%2C%20but%20the%20Constant%20prints%20better.%0A%20%20%20%20R1%20%3D%20Dot(Grad(u)%2C%20u)%20-%20nu%20*%20Div(Grad(u))%20%2B%20Grad(p)%0A%20%20%20%20R2%20%3D%20Div(u)%0A%20%20%20%20return%20R1%2C%20R2%2C%20sp%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Here%20%60R1%60%20represents%20the%20residual%20of%20the%20momentum%20vector%20equation%20%24%5Cmathcal%7BR%7D%5E1_%7B%5Ctheta%7D%24%2C%20whereas%20%60R2%60%20represents%20the%20residual%20of%20the%20scalar%20divergence%20constraint%20%24%5Cmathcal%7BR%7D%5E2_%7B%5Ctheta%7D%24.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20inspect%20the%20residuals%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(R1%2C%20display)%3A%0A%20%20%20%20display(R1)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%60R1%60%20represents%20a%20vector%20equation%20and%20we%20can%20expand%20it%20using%20%60doit%60%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(R1)%3A%0A%20%20%20%20R1.doit()%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20divergence%20constraint%20is%20a%20scalar%20equation%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(R2%2C%20display)%3A%0A%20%20%20%20display(R2)%0A%20%20%20%20display(R2.doit())%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20To%20solve%20the%20equations%20we%20will%20use%20a%20least%20squares%20method%20and%20for%20this%20we%20need%20to%20create%20collocation%20points%20both%20inside%20and%20on%20the%20domain.%20There%20are%20some%20simple%20helper%20functions%20in%20%60jaxfun.pinns.mesh%60%20that%20can%20help%20use%20create%20such%20points.%20Below%20we%20create%20a%20total%20of%20%24N%5E2%24%20points%20in%20a%20rectangular%20mesh.%20We%20separate%20these%20points%20into%20%24N_i%3DN%5E2-4N%24%20points%20%60xyi%60%20inside%20the%20domain%20%24%5Cboldsymbol%7Bx%7D%5E%7B%5COmega%7D%20%3D%20%5C%7B(x_i%2C%20y_i)%5C%7D_%7Bi%3D0%7D%5E%7BN_i-1%7D%24%2C%20and%20%24N_b%3D4N%24%20points%20%60xyb%60%20on%20the%20boundary%20of%20the%20domain%20%24%5Cboldsymbol%7Bx%7D%5E%7B%5Cpartial%20%5COmega%7D%20%3D%20%5C%7B(x_i%2C%20y_i)%5C%7D_%7Bi%3D0%7D%5E%7B4N-1%7D%24%2C%20including%20the%20four%20corners.%20The%20last%20point%20%60xyp%60%20is%20simply%20origo%20and%20used%20to%20fix%20the%20pressure%20in%20a%20single%20point.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(jnp)%3A%0A%20%20%20%20from%20jaxfun.pinns.mesh%20import%20Rectangle%0A%0A%20%20%20%20N%20%3D%2020%0A%20%20%20%20mesh%20%3D%20Rectangle(-1%2C%201%2C%20-1%2C%201)%0A%20%20%20%20xyi%20%3D%20mesh.get_points(N%20*%20N%2C%204%20*%20N%2C%20domain%3D%22inside%22%2C%20kind%3D%22random%22)%0A%20%20%20%20xyb%20%3D%20mesh.get_points(N%20*%20N%2C%204%20*%20N%2C%20domain%3D%22boundary%22%2C%20kind%3D%22random%22)%0A%20%20%20%20xyp%20%3D%20jnp.array(%5B%5B0.0%2C%200.0%5D%5D)%0A%20%20%20%20return%20xyb%2C%20xyi%2C%20xyp%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20boundary%20conditions%20on%20the%20velocity%20vector%20needs%20to%20be%20specified%20and%20to%20this%20end%20we%20can%20use%20a%20boundary%20function%20%24ue(x%2C%20y)%20%3D%20(1-x)%5E2(1%2Bx)%5E2%20%5Cboldsymbol%7Bi%7D%24%20created%20as%20a%20Sympy%20function%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(V%2C%20sp)%3A%0A%20%20%20%20x%2C%20y%20%3D%20V.system.base_scalars()%0A%20%20%20%20ub%20%3D%20sp.Mul(%0A%20%20%20%20%20%20%20%20sp.Piecewise((0%2C%20y%20%3C%201)%2C%20((1%20-%20x)%20**%202%20*%20(1%20%2B%20x)%20**%202%2C%20True))%2C%20V.system.i%0A%20%20%20%20)%0A%20%20%20%20return%20(ub%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20unknowns%20will%20now%20be%20found%20using%20the%20least%20squares%20method%2C%20which%20is%20to%20minimize%0A%0A%20%20%20%20%5Cbegin%7Bequation*%7D%0A%20%20%20%20%5Cunderset%7B%5Ctheta%20%5Cin%20%5Cmathbb%7BR%7D%5EM%7D%7B%5Ctext%7Bminimize%7D%7D%5C%2C%20L(%5Ctheta)%3A%3D%5Cfrac%7B1%7D%7BN_i%7D%5Csum_%7Bk%3D0%7D%5E1%5Csum_%7Bi%3D0%7D%5E%7BN_i-1%7D%20%5Cmathcal%7BR%7D%5E%7B1%7D_%7B%5Ctheta%7D%20(%5Cboldsymbol%7Bx%7D%5E%7B%5COmega%7D_i)%5E2%20%5Ccdot%20%5Cboldsymbol%7Bi%7D_k%20%2B%20%5Cfrac%7B1%7D%7BN_i%7D%5Csum_%7Bi%3D0%7D%5E%7BN_i-1%7D%20%5Cmathcal%7BR%7D%5E%7B2%7D_%7B%5Ctheta%7D(%5Cboldsymbol%7Bx%7D%5E%7B%5COmega%7D_i)%5E2%2B%20%5Cfrac%7B1%7D%7BN_b%7D%20%5Csum_%7Bk%3D0%7D%5E1%20%5Csum_%7Bi%3D0%7D%5E%7BN_b-1%7D%20(%5Cboldsymbol%7Bu%7D_%7B%5Ctheta%7D(%5Cboldsymbol%7Bx%7D%5E%7B%5Cpartial%20%5COmega%7D_i)%20-%20%5Cboldsymbol%7Bu%7D_b(%5Cboldsymbol%7Bx%7D%5E%7B%5Cpartial%20%5COmega%7D_i))%5E2%20%5Ccdot%20%5Cboldsymbol%7Bi%7D_k%20%2B%20p_%7B%5Ctheta%7D(0%2C%200)%0A%20%20%20%20%5Cend%7Bequation*%7D%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20where%20%24%5Cboldsymbol%7Bi%7D_0%20%3D%20%5Cboldsymbol%7Bi%7D%24%20and%20%24%5Cboldsymbol%7Bi%7D_1%20%3D%20%5Cboldsymbol%7Bj%7D%24.%0A%20%20%20%20We%20define%20the%20minimization%20problem%20using%20the%20jaxfun%20class%20%60Loss%60%2C%20which%20takes%20as%20arguments%20tuples%20containing%20the%20residual%2C%20the%20collocation%20points%20to%20use%20for%20that%20residual%2C%20the%20target%20(defaults%20to%20zero)%20and%20optionally%20some%20weights.%20The%20weights%20may%20be%20constants%20or%201D%20arrays%20of%20length%20the%20number%20of%20collocation%20points.%20If%20weights%20are%20provided%2C%20then%20these%20are%20applied%20to%20the%20squared%20residuals%20in%20each%20term%20above.%20For%20example%2C%20the%20divergence%20loss%20becomes%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cfrac%7B1%7D%7BN_i%7D%5Csum_%7Bi%3D0%7D%5E%7BN_i-1%7D%20%5Comega_i%20%5Cmathcal%7BR%7D%5E%7B2%7D_%7B%5Ctheta%7D(%5Cboldsymbol%7Bx%7D%5E%7B%5COmega%7D_i)%5E2%0A%20%20%20%20%24%24%0A%0A%20%20%20%20using%20weights%20%24%5C%7B%5Comega_i%5C%7D_%7Bi%3D0%7D%5E%7BN_i-1%7D%24.%20Below%20the%20pressure%20anchor%20is%20simply%20weighted%20with%20a%20constant%20factor%205.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(R1%2C%20R2%2C%20p%2C%20u%2C%20ub%2C%20xyb%2C%20xyi%2C%20xyp)%3A%0A%20%20%20%20from%20jaxfun.pinns%20import%20Loss%0A%0A%20%20%20%20loss_fn%20%3D%20Loss((R1%2C%20xyi)%2C%20(R2%2C%20xyi)%2C%20(u%20-%20ub%2C%20xyb)%2C%20(p%2C%20xyp%2C%200%2C%205))%0A%20%20%20%20return%20(loss_fn%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20There%20are%20six%20scalar%20losses%20computed%20with%20%60loss_fn%60%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn)%3A%0A%20%20%20%20print(loss_fn.residuals)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Each%20loss%20may%20be%20computed%20individually%2C%20or%20we%20can%20compute%20the%20whole%20sum%20%24L(%5Ctheta)%24.%20Since%20we%20have%20not%20started%20the%20least%20squares%20solver%20yet%2C%20the%20module%20for%20now%20only%20contains%20randomly%20initialized%20weights%20and%20the%20loss%20is%20significant%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn%2C%20module)%3A%0A%20%20%20%20loss_fn(module)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20To%20create%20a%20solver%2C%20we%20use%20the%20%5Boptax%5D(https%3A%2F%2Foptax.readthedocs.io%2Fen%2Flatest%2F)%20module%2C%20and%20start%20with%20the%20Adam%20optimizer%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn%2C%20module)%3A%0A%20%20%20%20from%20jaxfun.pinns.optimizer%20import%20Trainer%2C%20adam%0A%0A%20%20%20%20trainer%20%3D%20Trainer(loss_fn)%0A%20%20%20%20opt_adam%20%3D%20adam(module)%0A%20%20%20%20trainer.train(opt_adam%2C%205000%2C%20epoch_print%3D1000)%0A%20%20%20%20return%20(trainer%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20loss%20has%20now%20been%20reduced%2C%20but%20is%20still%20far%20from%20zero.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn%2C%20module)%3A%0A%20%20%20%20loss_fn(module)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20To%20refine%20the%20solution%2C%20we%20use%20a%20low-memory%20BFGS%20solver%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(module%2C%20trainer)%3A%0A%20%20%20%20from%20jaxfun.pinns.optimizer%20import%20lbfgs%0A%0A%20%20%20%20opt_lbfgs%20%3D%20lbfgs(module%2C%20memory_size%3D100)%0A%20%20%20%20trainer.train(opt_lbfgs%2C%205000%2C%20epoch_print%3D1000)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn%2C%20module)%3A%0A%20%20%20%20loss_fn(module)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20And%20finally%20an%20incomplete%20Newton%20method%2C%20which%20is%20still%20quite%20slow%20because%20we%20have%20not%20yet%20implemented%20a%20preconditioner.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(module%2C%20trainer)%3A%0A%20%20%20%20from%20jaxfun.pinns.optimizer%20import%20GaussNewton%0A%0A%20%20%20%20opt_hess%20%3D%20GaussNewton(module%2C%20use_lstsq%3DFalse%2C%20cg_max_iter%3D100)%0A%20%20%20%20trainer.train(opt_hess%2C%20100%2C%20epoch_print%3D10)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20now%20visualize%20the%20solution%20using%20%5Bmatplotlib%5D(https%3A%2F%2Fmatplotlib.org%2F)%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(jnp%2C%20module)%3A%0A%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%0A%20%20%20%20yj%20%3D%20jnp.linspace(-1%2C%201%2C%2050)%0A%20%20%20%20xx%2C%20yy%20%3D%20jnp.meshgrid(yj%2C%20yj%2C%20sparse%3DFalse%2C%20indexing%3D%22ij%22)%0A%20%20%20%20z%20%3D%20jnp.column_stack((xx.ravel()%2C%20yy.ravel()))%0A%20%20%20%20uvp%20%3D%20module(z)%0A%20%20%20%20fig%2C%20axs%20%3D%20plt.subplots(1%2C%203%2C%20sharex%3DTrue%2C%20sharey%3DTrue%2C%20figsize%3D(6%2C%202))%0A%20%20%20%20axs%5B0%5D.contourf(xx%2C%20yy%2C%20uvp%5B%3A%2C%200%5D.reshape(xx.shape)%2C%20100)%0A%20%20%20%20axs%5B0%5D.set_title(r%22%24u_x%24%22)%0A%20%20%20%20axs%5B1%5D.contourf(xx%2C%20yy%2C%20uvp%5B%3A%2C%201%5D.reshape(xx.shape)%2C%20100)%0A%20%20%20%20axs%5B1%5D.set_title(r%22%24u_y%24%22)%0A%20%20%20%20axs%5B2%5D.contourf(xx%2C%20yy%2C%20uvp%5B%3A%2C%202%5D.reshape(xx.shape)%2C%20100)%0A%20%20%20%20axs%5B2%5D.set_title(r%22%24p%24%22)%0A%20%20%20%20fig.set_tight_layout(%22tight%22)%0A%20%20%20%20plt.show()%0A%20%20%20%20return%20plt%2C%20uvp%2C%20xx%2C%20yy%0A%0A%0A%40app.cell%0Adef%20_(plt%2C%20uvp%2C%20xx%2C%20yy)%3A%0A%20%20%20%20plt.figure(figsize%3D(4%2C%204))%0A%20%20%20%20plt.quiver(xx%2C%20yy%2C%20uvp%5B%3A%2C%200%5D.reshape(xx.shape)%2C%20uvp%5B%3A%2C%201%5D.reshape(xx.shape))%0A%20%20%20%20plt.show()%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20also%20compute%20the%20losses%20of%20each%20equation.%20Here%20we%20get%20the%20loss%20of%20each%20collocation%20point%20in%20%24%5Cboldsymbol%7Bx%7D%5E%7B%5COmega%7D%24%20for%20the%20momentum%20equation%20in%20the%20x-direction%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn%2C%20module)%3A%0A%20%20%20%20loss_fn.compute_residual_i(module%2C%200)%5B%3A10%5D%20%20%23%20plot%20only%2010%20numbers%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20loss%20can%20be%20plotted%20using%20for%20example%20scatter.%20Below%20is%20the%20scatter%20plot%20of%20the%20error%20in%20the%20divergence%20constraint.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss_fn%2C%20module%2C%20plt%2C%20xyi)%3A%0A%20%20%20%20plt.figure(figsize%3D(4%2C%203))%0A%20%20%20%20plt.scatter(*xyi.T%2C%20c%3Dloss_fn.compute_residual_i(module%2C%202)%2C%20s%3D20)%0A%20%20%20%20plt.colorbar()%0A%20%20%20%20plt.show()%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20And%20finally%20a%20vector%20plot%20of%20the%20velocity%20vectors%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(jnp%2C%20plt%2C%20uvp%2C%20xx%2C%20yy)%3A%0A%20%20%20%20plt.figure(figsize%3D(4%2C%204))%0A%20%20%20%20plt.quiver(xx%2C%20yy%2C%20uvp%5B%3A%2C%200%5D%2C%20uvp%5B%3A%2C%201%5D%2C%20jnp.linalg.norm(uvp%5B%3A%2C%20%3A2%5D%2C%20axis%3D1))%0A%20%20%20%20plt.show()%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%0A%20%20%20%20return%20(mo%2C)%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
f39a98317f2b0f59807afbfa30ad6697