diff --git a/MindFlow/applications/research/ns_cylinder_pinns/NS_Cylinder.ipynb b/MindFlow/applications/research/ns_cylinder_pinns/NS_Cylinder.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..d35840388a78b2900d7fbdaf5efdac3c68d87f14
--- /dev/null
+++ b/MindFlow/applications/research/ns_cylinder_pinns/NS_Cylinder.ipynb
@@ -0,0 +1,365 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "d51c9bf8",
+ "metadata": {},
+ "source": [
+ "## Viscous incompressible flow over a circular cylinder"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0de850bf",
+ "metadata": {},
+ "source": [
+ "## Environment Setup"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "03a34ffd",
+ "metadata": {},
+ "source": [
+ "The case requires Python 3.7 and MindSpore version 2.0.0 or above, and the Sciai toolkit must be installed."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2b9aef08",
+ "metadata": {},
+ "source": [
+ "## Overview"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d3ddb3b5",
+ "metadata": {},
+ "source": [
+ "The solution of fluid problems has long been a major challenge in the fields of science and engineering. It is crucial in areas such as aerospace and weather forecasting. In the past few decades, the solution of such problems has mainly relied on traditional numerical methods, including the Finite Difference Method (FDM), the Finite Volume Method (FVM), and the Finite Element Method (FEM). However, these methods also have certain limitations, such as their dependence on mesh and the difficulty of solving inverse problems."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "216f86c4",
+ "metadata": {},
+ "source": [
+ "In recent years, with the rapid development of deep learning technologies and computational resources, solving partial differential equations using neural networks has become a major research hotspot, achieving successful applications in many fields and serving as a strong complement to traditional paradigms."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4c694f52",
+ "metadata": {},
+ "source": [
+ "The viscous incompressible flow over a circular cylinder is governed by the Navier-Stokes equations, which describe the motion of the fluid around the cylinder. The Navier-Stokes equations are a highly nonlinear system of equations. In this case, the two-dimensional steady form of these equations is solved using Physics-Informed Neural Networks (PINNs) combined with volume weighting methods."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b59a7f54",
+ "metadata": {},
+ "source": [
+ "## Problem Description"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2f8c5006",
+ "metadata": {},
+ "source": [
+ "The two-dimensional steady Navier-Stokes equations is as follows:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5f08d997",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d688635e",
+ "metadata": {},
+ "source": [
+ "The boundary conditions for this case are set as follows: velocity inlet is used for the inlet and upper and lower boundaries, and pressure outlet is used for the outlet. No-slip condition is used for the wall."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "37405ff8",
+ "metadata": {},
+ "source": [
+ "The input of PINNs in this case is spatial coordinates x, y. The output is u, v, p."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23250593",
+ "metadata": {},
+ "source": [
+ "## Technology Path"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b093e904",
+ "metadata": {},
+ "source": [
+ "The process for solving this problem is as follows:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ae3bb7c1",
+ "metadata": {},
+ "source": [
+ "1.Dataset Construction.\n",
+ "2.Model Construction.\n",
+ "3.Calculate loss function.\n",
+ "4.Set optimizer.\n",
+ "5.Model Training.\n",
+ "6.Visualization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c43b7dac",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "import numpy as np\n",
+ "import mindspore as ms\n",
+ "from mindflow.utils import load_yaml_config\n",
+ "from sciai.common import lbfgs_train"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "220c22cc",
+ "metadata": {},
+ "source": [
+ "The following src package can be downloaded from src."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a2ab10d1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from src import create_train_dataset, Visualization, Net, CalculateLoss\n",
+ "\n",
+ "seed = 2233\n",
+ "np.random.seed(seed)\n",
+ "ms.set_seed(seed)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "19f7802f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ms.context.set_context(device_target=args.device_target, device_id=args.device_id, mode=ms.PYNATIVE_MODE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f731b7e4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = load_yaml_config('./configs/NS_Cylinder_VW.yaml')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "42db1dde",
+ "metadata": {},
+ "source": [
+ "## Dataset Construction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d4e0e58e",
+ "metadata": {},
+ "source": [
+ "The collocation points in this case are obtained through pointwise and imported through .txt file, while the boundary points are obtained through random sampling."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0180abc2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create dataset\n",
+ "X_inlet, U_inlet, X_wall, U_wall, X_boundary, U_boundary, X_outlet, p_outlet, X, Volume = create_train_dataset(config)\n",
+ "rho = config['Equ_para']['rho']\n",
+ "miu = config['Equ_para']['miu']\n",
+ "weight_b = config['weight_b']\n",
+ "weight_f = config['weight_f']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5bd2b92a",
+ "metadata": {},
+ "source": [
+ "## Model Construction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "425b8542",
+ "metadata": {},
+ "source": [
+ "The case uses a fully connected network with a hidden layer depth of 5, a width of 64, and the activation function is tanh."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8793ef9f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define model\n",
+ "inputs = config['model']['inputs']\n",
+ "layers = config['model']['layers']\n",
+ "neurons = config['model']['neurons']\n",
+ "outputs = config['model']['outputs']\n",
+ "NN_architecture = [inputs] + layers*[neurons] + [outputs]\n",
+ "net = Net(NN_architecture)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fc8de115",
+ "metadata": {},
+ "source": [
+ "## Calculate loss function"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4a620514",
+ "metadata": {},
+ "source": [
+ "The loss function includes equation loss and boundary loss, which are combined."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "51334a61",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define loss\n",
+ "net_lb = CalculateLoss(net, X_inlet, U_inlet, X_wall, U_wall, X_boundary, U_boundary, X_outlet, p_outlet, Volume,\n",
+ " rho, miu, weight_b, weight_f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "67dd28e5",
+ "metadata": {},
+ "source": [
+ "## Set optimizer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "98116615",
+ "metadata": {},
+ "source": [
+ "The case requires the second-order optimizer L-BFGS."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e42ab864",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define optimizer\n",
+ "max_iters = config['max_iters']\n",
+ "lbfgs_train(net_lb, (X,), max_iters)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9d79c7f7",
+ "metadata": {},
+ "source": [
+ "## Model Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ba576a27",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "start_time = time.time()\n",
+ "loss = net_lb(X)\n",
+ "print(loss)\n",
+ "end_time = time.time()\n",
+ "print(\"Train time: {} s\".format(end_time - start_time))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e9d1b564",
+ "metadata": {},
+ "source": [
+ "## Visualization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "25d3f2ee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Visualization(net, X_wall, config)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/NS_Cylinder_CN.ipynb b/MindFlow/applications/research/ns_cylinder_pinns/NS_Cylinder_CN.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..cf287908a11394c380885f6d13e2c0d7591f994c
--- /dev/null
+++ b/MindFlow/applications/research/ns_cylinder_pinns/NS_Cylinder_CN.ipynb
@@ -0,0 +1,353 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "d884253a",
+ "metadata": {},
+ "source": [
+ "## 不可压粘性圆柱绕流问题"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d0757e4e",
+ "metadata": {},
+ "source": [
+ "## 环境安装"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7aa16c4d",
+ "metadata": {},
+ "source": [
+ "本案例要求Python3.7版本,MindSpore >= 2.0.0 以上版本,并需安装Sciai工具库。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2f0ae0af",
+ "metadata": {},
+ "source": [
+ "## 概述"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fecff9d1",
+ "metadata": {},
+ "source": [
+ "流体问题的求解一直是科学和工程领域的一大问题,其在航空航天、气象预测等领域至关重要。在过去的几十年里,此类问题的求解主要通过传统数值方法,包括有限差分法(FDM),有限体积法(FVM)和有限元法(FEM)等。但这些方法也存在一些局限性,包括依赖网格、反问题求解困难等。\n",
+ "\n",
+ "近年来,随着深度学习技术和计算资源的迅速发展,通过神经网络求解偏微分方程成为一大研究热点,并在许多领域取得了成功应用,是对传统范式的有力补充。\n",
+ "\n",
+ "不可压粘性圆柱绕流问题受Navier-Stokes方程控制,该方程描述了圆柱周围流体的运动规律。Navier-Stokes方程是一个非线性极强的方程组,本例涉及其二维定常形式,通过PINNs(Physics INformed Neural Networks)并结合体积加权的改进策略进行求解。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "99e4ee78",
+ "metadata": {},
+ "source": [
+ "## 问题描述"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "12d59411",
+ "metadata": {},
+ "source": [
+ "二维定常Navier-Stokes方程的形式如下:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "501afe3e",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e0c291b7",
+ "metadata": {},
+ "source": [
+ "本案例的边界条件设置:入口处和上下边界采用速度入口,出口采用压力出口,圆柱物面使用无滑移条件。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "071402aa",
+ "metadata": {},
+ "source": [
+ "本案例PINNs的输入为空间坐标x,y;输出为u,v,p。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fdb8a09f",
+ "metadata": {},
+ "source": [
+ "## 技术路径"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "be7c8fc9",
+ "metadata": {},
+ "source": [
+ "该问题求解的具体流程如下:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6c91528e",
+ "metadata": {},
+ "source": [
+ "1.创建数据集;\n",
+ "2.构建模型;\n",
+ "3.计算损失函数;\n",
+ "4.设置优化器;\n",
+ "5.模型训练;\n",
+ "6.可视化。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "eb832a1d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "import numpy as np\n",
+ "import mindspore as ms\n",
+ "from mindflow.utils import load_yaml_config\n",
+ "from sciai.common import lbfgs_train"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ae7e5e17",
+ "metadata": {},
+ "source": [
+ "下述src包可以在src下载。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6d51a6b8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from src import create_train_dataset, Visualization, Net, CalculateLoss\n",
+ "\n",
+ "seed = 2233\n",
+ "np.random.seed(seed)\n",
+ "ms.set_seed(seed)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1a6b7832",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ms.context.set_context(device_target=args.device_target, device_id=args.device_id, mode=ms.PYNATIVE_MODE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "06132008",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = load_yaml_config('./configs/NS_Cylinder_VW.yaml')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "36018775",
+ "metadata": {},
+ "source": [
+ "## 创建数据集"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "426dd107",
+ "metadata": {},
+ "source": [
+ "本案例的配置点通过pointwise获得后通过.txt文件导入,边界点通过随机采样获得。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4cbd704c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create dataset\n",
+ "X_inlet, U_inlet, X_wall, U_wall, X_boundary, U_boundary, X_outlet, p_outlet, X, Volume = create_train_dataset(config)\n",
+ "rho = config['Equ_para']['rho']\n",
+ "miu = config['Equ_para']['miu']\n",
+ "weight_b = config['weight_b']\n",
+ "weight_f = config['weight_f']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3b188d63",
+ "metadata": {},
+ "source": [
+ "## 构建模型"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cf3fff9e",
+ "metadata": {},
+ "source": [
+ "本案例使用全连接网络,隐藏层深度为5,宽度为64,激活函数为tanh。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dc23702b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define model\n",
+ "inputs = config['model']['inputs']\n",
+ "layers = config['model']['layers']\n",
+ "neurons = config['model']['neurons']\n",
+ "outputs = config['model']['outputs']\n",
+ "NN_architecture = [inputs] + layers*[neurons] + [outputs]\n",
+ "net = Net(NN_architecture)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4a55cf93",
+ "metadata": {},
+ "source": [
+ "## 计算损失函数"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "36619ee1",
+ "metadata": {},
+ "source": [
+ "损失函数包括方程损失和边界损失,两者叠加。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "183f26b6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define loss\n",
+ "net_lb = CalculateLoss(net, X_inlet, U_inlet, X_wall, U_wall, X_boundary, U_boundary, X_outlet, p_outlet, Volume,\n",
+ " rho, miu, weight_b, weight_f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "60daca5a",
+ "metadata": {},
+ "source": [
+ "## 设置优化器"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "98b5446d",
+ "metadata": {},
+ "source": [
+ "本案例需要使用二阶优化器L-BFGS。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "046f7b6e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define optimizer\n",
+ "max_iters = config['max_iters']\n",
+ "lbfgs_train(net_lb, (X,), max_iters)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d8448774",
+ "metadata": {},
+ "source": [
+ "## 模型训练"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0ecb63ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "start_time = time.time()\n",
+ "loss = net_lb(X)\n",
+ "print(loss)\n",
+ "end_time = time.time()\n",
+ "print(\"Train time: {} s\".format(end_time - start_time))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1a50d6a2",
+ "metadata": {},
+ "source": [
+ "## 可视化"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c9245604",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Visualization(net, X_wall, config)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/Readme.md b/MindFlow/applications/research/ns_cylinder_pinns/Readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..eac85ac4cbd0c7e060461c44406ab293556083d5
--- /dev/null
+++ b/MindFlow/applications/research/ns_cylinder_pinns/Readme.md
@@ -0,0 +1,23 @@
+# Viscous incompressible flow over a circular cylinder
+
+# Overview
+
+The viscous incompressible flow over a circular cylinder is governed by the Navier-Stokes equations, which describe the motion of the fluid around the cylinder. The Navier-Stokes equations are a highly nonlinear system of equations. In this case, the two-dimensional steady form of these equations is solved using Physics-Informed Neural Networks (PINNs) combined with volume weighting methods.
+
+## QuickStart
+
+You can download dataset from dataset for model evaluation. Save these dataset at ./dataset.
+
+## Run Option 1: Call train.py from command line
+
+python train.py --mode PYNATIVE --device_target Ascend --device_id 0 --config_file_path ./configs/NS_Cylinder_VW.yaml
+
+where:
+--mode is the running mode. 'PYNATIVE' indicates dynamic graph mode. You can refer to MindSpore official website for details.
+--device_target indicates the computing platform. You can choose 'Ascend'.
+--device_id indicates the index of NPU or GPU. Default 0.
+--config_file_path indicates the path of the parameter file. Default './configs/NS_Cylinder_VW.yaml';
+
+## Run Option 2: Run Jupyter Notebook
+
+You can use Chinese or English Jupyter Notebook to run the training and evaluation code line-by-line.
\ No newline at end of file
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/Readme_CN.md b/MindFlow/applications/research/ns_cylinder_pinns/Readme_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..2ca13012e137d348886652989911840291db1e21
--- /dev/null
+++ b/MindFlow/applications/research/ns_cylinder_pinns/Readme_CN.md
@@ -0,0 +1,22 @@
+# 不可压粘性圆柱绕流问题
+
+# 概述
+
+不可压粘性圆柱绕流问题受Navier-Stokes方程控制,该方程描述了圆柱周围流体的运动规律。Navier-Stokes方程是一个非线性极强的方程组,本例涉及其二维定常形式,通过PINNs(Physics INformed Neural Networks)并结合体积加权的改进策略进行求解。
+
+# 快速开始
+
+从dataset中下载训练及测试数据,并保存在./dataset目录下。
+
+# 训练方式一:在命令行调用train.py脚本
+
+python train.py --mode PYNATIVE --device_target Ascend --device_id 0 --config_file_path ./configs/NS_Cylinder_VW.yaml
+
+其中,--mode表示运行的模式,'PYNATIVE'表示动态图模式,详见MindSpore官网;
+--device_target表示使用的计算平台类型,选择'Ascend';
+--device_id表示使用的计算卡编号,可按照实际情况填写,默认值0;
+--config_file_path表示参数文件的路径,默认值'./configs/NS_Cylinder_VW.yaml';
+
+## 训练方式二:运行Jupyter Notebook
+
+您可以使用中文版和英文版Jupyter Noterbook逐行运行代码,进行训练和测试。
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/configs/NS_Cylinder_VW.yaml b/MindFlow/applications/research/ns_cylinder_pinns/configs/NS_Cylinder_VW.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5a0e112bbfea68527dafa328078423a283b4fa58
--- /dev/null
+++ b/MindFlow/applications/research/ns_cylinder_pinns/configs/NS_Cylinder_VW.yaml
@@ -0,0 +1,28 @@
+Boundary:
+ N_boundary: 50
+ N_outlet: 75
+ N_wall: 201
+ R_cylinder: 0.5
+ R_far: 15
+ alpha: 0.0
+ p_outlet: 0.0
+ theta_b: 0.0
+ theta_e: 6.28318530718
+ u_wall: 0.0
+ v_wall: 0.0
+ x_range: 20
+ y_range: 15
+Equ_para:
+ miu: 0.025
+ rho: 1.0
+Volume: 476.71458676443
+bandwidth: 0.03
+dataset_path: ./dataset
+max_iters: 100000
+model:
+ inputs: 2
+ layers: 5
+ neurons: 64
+ outputs: 3
+weight_b: 3
+weight_f: 50000
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/dataset/cp_re40.mat b/MindFlow/applications/research/ns_cylinder_pinns/dataset/cp_re40.mat
new file mode 100644
index 0000000000000000000000000000000000000000..76adf96207bd7cb8111cff2710fab7a0fb3ebbcf
Binary files /dev/null and b/MindFlow/applications/research/ns_cylinder_pinns/dataset/cp_re40.mat differ
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/dataset/mesh_point.mat b/MindFlow/applications/research/ns_cylinder_pinns/dataset/mesh_point.mat
new file mode 100644
index 0000000000000000000000000000000000000000..f490ddc0cfa760e8a50771414fe7e45b266bb3f3
Binary files /dev/null and b/MindFlow/applications/research/ns_cylinder_pinns/dataset/mesh_point.mat differ
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/src/__init__.py b/MindFlow/applications/research/ns_cylinder_pinns/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a095662d6cac11ac0ac94c57339556710b192d4
--- /dev/null
+++ b/MindFlow/applications/research/ns_cylinder_pinns/src/__init__.py
@@ -0,0 +1,10 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Oct 24 16:02:54 2024
+
+@author: songjiahao
+"""
+
+from .dataset import create_train_dataset, create_test_dataset
+from .utils import Visualization
+from .model import Net, CalculateLoss
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/src/dataset.py b/MindFlow/applications/research/ns_cylinder_pinns/src/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4581c97103dc1c971beefd7a899ed9eae1e70b5
--- /dev/null
+++ b/MindFlow/applications/research/ns_cylinder_pinns/src/dataset.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Oct 24 12:10:31 2024
+
+@author: songjiahao
+"""
+
+import os
+import numpy as np
+import mindspore as ms
+from mindspore import ops
+from scipy.stats import gaussian_kde
+from scipy.io import loadmat
+
+def create_train_dataset(config):
+ '''Boundary_points'''
+ r_far = config['Boundary']['R_far']
+ r_cylinder = config['Boundary']['R_cylinder']
+ n_wall = config['Boundary']['N_wall']
+ theta_b = config['Boundary']['theta_b']
+ theta_e = config['Boundary']['theta_e']
+ theta = ms.Tensor(np.linspace(theta_b, theta_e, n_wall).reshape(-1, 1), ms.float32)
+ x_b1 = ops.concat([r_far * ops.cos(theta), r_far * ops.sin(theta)], 1)
+ x_b2 = ops.concat([r_cylinder * ops.cos(theta), r_cylinder * ops.sin(theta)], 1)
+ alpha = ms.Tensor(config['Boundary']['alpha'])
+ x_inlet = x_b1[int((n_wall-1)/4):int(n_wall-(n_wall-1)/4), :]
+ u_inlet = ops.concat([(ops.cos(alpha)).reshape(-1, 1), (ops.sin(alpha)).reshape(-1, 1)], 1)
+ x_wall = x_b2
+ u_wall = ops.concat([ms.Tensor([[config['Boundary']['u_wall']]], ms.float32),
+ ms.Tensor([[config['Boundary']['v_wall']]], ms.float32)], 1)
+ n_boundary = config['Boundary']['N_boundary']
+ x_up = np.concatenate([config['Boundary']['x_range'] * np.random.rand(n_boundary, 1),
+ config['Boundary']['y_range'] * np.ones((n_boundary, 1))], 1)
+ x_down = np.concatenate([config['Boundary']['x_range'] * np.random.rand(n_boundary, 1),
+ config['Boundary']['y_range']*-1*np.ones((n_boundary, 1))], 1)
+ x_boundary = np.concatenate([x_up, x_down], 0)
+ x_boundary = ms.Tensor(x_boundary, ms.float32)
+ u_boundary = ops.concat([(ops.cos(alpha)).reshape(-1, 1), (ops.sin(alpha)).reshape(-1, 1)], 1)
+ n_outlet = config['Boundary']['N_outlet']
+ x_outlet = np.concatenate([config['Boundary']['x_range'] * np.ones((n_outlet, 1)),
+ config['Boundary']['y_range'] * (2 * np.random.rand(n_outlet, 1) - 1)], 1)
+ x_outlet = ms.Tensor(x_outlet, ms.float32)
+ p_outlet = ms.Tensor([[config['Boundary']['p_outlet']]], ms.float32)
+ dataset_path = config['dataset_path']
+ file_name = 'mesh_point.mat'
+ file_path = os.path.join(dataset_path, file_name)
+ mesh = loadmat(file_path)
+ mesh = mesh['data']
+ x = ms.Tensor(mesh, ms.float32)
+ kde = gaussian_kde(mesh.T, bw_method=config['bandwidth'])
+ xx_f = mesh[:, [0]]
+ yy_f = mesh[:, [1]]
+ grid = np.vstack([xx_f.ravel(), yy_f.ravel()])
+ density = kde(grid)
+ recip = 1 / density
+ recip_norm = recip / recip.sum()
+ volume = (config['Volume'] * recip_norm).reshape(-1, 1)
+ volume = ms.Tensor(volume, ms.float32)
+ return x_inlet, u_inlet, x_wall, u_wall, x_boundary, u_boundary, x_outlet, p_outlet, x, volume
+
+def create_test_dataset(config):
+ '''cp_wall'''
+ dataset_path = config['dataset_path']
+ file_name = 'cp_re40.mat'
+ file_path = os.path.join(dataset_path, file_name)
+ cp = loadmat(file_path)
+ cp = cp['data']
+ x_wall = cp[:, [0]]
+ cp_wall = cp[:, [1]]
+ return x_wall, cp_wall
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/src/model.py b/MindFlow/applications/research/ns_cylinder_pinns/src/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..38e681057fc7349e6e7011bc36996741625eb877
--- /dev/null
+++ b/MindFlow/applications/research/ns_cylinder_pinns/src/model.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Oct 24 15:12:16 2024
+
+@author: songjiahao
+"""
+
+import mindspore as ms
+from mindspore import nn, ops, jit
+from sciai.architecture import MSE
+
+class ActFunc(nn.Cell):
+ '''ActFun'''
+ def __init__(self):
+ pass
+
+ def construct(self, x):
+ return ops.tanh(x)
+
+class Net(nn.Cell):
+ '''FCNN'''
+ def __init__(self, layer, bn=False, drop=False):
+ super(Net, self).__init__()
+ self.layer = layer
+ self.bn = bn
+ layers = []
+ for i in range(len(self.layer) - 1):
+ layers.append(nn.Dense(self.layer[i], self.layer[i + 1]))
+ if i != (len(self.layer) - 2):
+ if bn:
+ layers.append(nn.BatchNorm1d(self.layer[i + 1]))
+ if drop:
+ layers.append(nn.Dropout(keep_prob=1 - drop))
+ layers.append(ActFunc())
+ self.net = nn.SequentialCell(layers)
+
+ def construct(self, x):
+ return self.net(x)
+
+class CalculateLoss(nn.Cell):
+ '''CalculateLoss'''
+ def __init__(self, network, x_inlet, u_inlet, x_wall, u_wall, x_boundary, u_boundary, x_outlet, p_outlet, volume,
+ rho, miu, weight_b, weight_f):
+ super().__init__()
+ self.model = network
+ self.x_inlet = x_inlet
+ self.u_inlet = u_inlet
+ self.x_wall = x_wall
+ self.u_wall = u_wall
+ self.x_boundary = x_boundary
+ self.u_boundary = u_boundary
+ self.x_outlet = x_outlet
+ self.p_outlet = p_outlet
+ self.volume = volume
+ self.rho = rho
+ self.miu = miu
+ self.weight_b = weight_b
+ self.weight_f = weight_f
+ self.mse = MSE()
+
+ def construct(self, x):
+ '''Loss'''
+ u_inlet_pred = self.model(self.x_inlet)
+ mseb1 = self.mse(u_inlet_pred[:, [0]] - self.u_inlet[:, [0]]) + \
+ self.mse(u_inlet_pred[:, [1]] - self.u_inlet[:, [1]])
+ u_wall_pred = self.model(self.x_wall)
+ mseb2 = self.mse(u_wall_pred[:, [0]] - self.u_wall[:, [0]]) + \
+ self.mse(u_wall_pred[:, [1]] - self.u_wall[:, [1]])
+ u_boundary_pred = self.model(self.x_boundary)
+ mseb3 = self.mse(u_boundary_pred[:, [0]] - self.u_boundary[:, [0]]) + \
+ self.mse(u_boundary_pred[:, [1]] - self.u_boundary[:, [1]])
+ u_outlet_pred = self.model(self.x_outlet)
+ mseb4 = self.mse(u_outlet_pred[:, [2]] - self.p_outlet)
+ mseb = mseb1 + mseb2 + mseb3 + mseb4
+
+ def forward_out1(xi):
+ return self.model(xi)[:, 0:1] # 第一个输出
+
+ @jit(compile_once=True)
+ def gradient_out1(xi):
+ return ms.grad(forward_out1)(xi)
+
+ def forward_out2(xi):
+ return self.model(xi)[:, 1:2] # 第一个输出
+
+ @jit(compile_once=True)
+ def gradient_out2(xi):
+ return ms.grad(forward_out2)(xi)
+
+ def forward_out3(xi):
+ return self.model(xi)[:, 2:3] # 第一个输出
+
+ @jit(compile_once=True)
+ def gradient_out3(xi):
+ return ms.grad(forward_out3)(xi)
+
+ # 计算二阶导数
+ def second_order_grad(xi):
+ u_xx = ms.grad(lambda y: gradient_out1(y)[:, 0:1])(xi)[:, 0:1] # u_xx
+ u_yy = ms.grad(lambda y: gradient_out1(y)[:, 1:2])(xi)[:, 1:2] # u_yy
+
+ v_xx = ms.grad(lambda y: gradient_out2(y)[:, 0:1])(xi)[:, 0:1] # u_xx
+ v_yy = ms.grad(lambda y: gradient_out2(y)[:, 1:2])(xi)[:, 1:2] # u_yy
+ return u_xx, u_yy, v_xx, v_yy
+
+ us = self.model(x)
+ u = us[:, [0]]
+ v = us[:, [1]]
+ u_xs = gradient_out1(x)
+ v_xs = gradient_out2(x)
+ p_xs = gradient_out3(x)
+ u_xx, u_yy, v_xx, v_yy = second_order_grad(x)
+ u_x = u_xs[:, [0]]
+ u_y = u_xs[:, [1]]
+ v_x = v_xs[:, [0]]
+ v_y = v_xs[:, [1]]
+ p_x = p_xs[:, [0]]
+ p_y = p_xs[:, [1]]
+
+ res_m = u_x + v_y
+ res_u = self.rho*(u*u_x + v*u_y) + p_x - self.miu*(u_xx + u_yy)
+ res_v = self.rho*(u*v_x + v*v_y) + p_y -self. miu*(v_xx + v_yy)
+ msef = ((res_m**2 + res_u**2 + res_v**2)*self.volume**2).sum() / (self.volume**2).sum()
+
+ loss = self.weight_b*mseb + self.weight_f*msef
+ print(f"Boundary Loss: {self.weight_b*mseb.asnumpy()}, PDE Loss: {(self.weight_f**msef).asnumpy()}")
+ return loss
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/src/utils.py b/MindFlow/applications/research/ns_cylinder_pinns/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec2249111b58e0586cdd5f6d85630a8a414e6377
--- /dev/null
+++ b/MindFlow/applications/research/ns_cylinder_pinns/src/utils.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Oct 24 15:30:02 2024
+
+@author: songjiahao
+"""
+import os
+import matplotlib.pyplot as plt
+from scipy.io import loadmat
+
+def Visualization(model, x_wall, config):
+ '''Visualization'''
+ p_wall_pred = model(x_wall)[:, [2]].asnumpy()
+ cp_wall_pred = 2*p_wall_pred
+ dataset_path = config['dataset_path']
+ file_name = 'cp_re40.mat'
+ file_path = os.path.join(dataset_path, file_name)
+ cp = loadmat(file_path)
+ cp = cp['data']
+ plt.figure()
+ plt.plot(cp[:, [0]], cp[:, [1]], 'r', LineWidth=3.5, label='Ref')
+ plt.plot(X_wall[:, [0]], cp_wall_pred, 'k--', LineWidth=3.5, label='VW-PINNs')
+ plt.xticks(fontname="Times New Roman", fontsize=20)
+ plt.yticks(fontname="Times New Roman", fontsize=20)
+ plt.xlabel('x', fontdict={'fontstyle': 'italic'}, fontname="Times New Roman", fontsize=22)
+ plt.ylabel('$C_p$', fontdict={'fontstyle': 'italic'}, fontname="Times New Roman", fontsize=22)
+ legend_font = {'family': 'Times New Roman',
+ 'size': 18}
+ plt.legend(prop=legend_font)
+ plt.show()
diff --git a/MindFlow/applications/research/ns_cylinder_pinns/train.py b/MindFlow/applications/research/ns_cylinder_pinns/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c409dbe0e3083e37d3055e2e050f4ee55ad8a3b6
--- /dev/null
+++ b/MindFlow/applications/research/ns_cylinder_pinns/train.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Oct 24 16:05:40 2024
+
+@author: songjiahao
+"""
+
+import argparse
+import os
+import time
+import numpy as np
+import mindspore as ms
+from mindflow.utils import load_yaml_config
+from sciai.common import lbfgs_train
+
+from src import create_train_dataset, Visualization, Net, CalculateLoss
+
+seed = 2233
+np.random.seed(seed)
+ms.set_seed(seed)
+
+def parse_args():
+ '''Parse input args'''
+ parser = argparse.ArgumentParser(description="NS cylinder train")
+ parser.add_argument("--config_file_path", type=str, default="./configs/NS_Cylinder_VW.yaml")
+ parser.add_argument("--device_target", type=str, default="Ascend")
+ parser.add_argument("--device_id", type=int, default=0)
+ parser.add_argument("--mode", type=str, default="PYNATIVE")
+
+ input_args = parser.parse_args()
+ return input_args
+
+def train(train_args):
+ '''Train and evaluate the network'''
+ # load configurations
+ config = load_yaml_config(train_args.config_file_path)
+
+ # create dataset
+ x_inlet, u_inlet, x_wall, u_wall, x_boundary, u_boundary, x_outlet, \
+ p_outlet, x, volume = create_train_dataset(config)
+ rho = config['Equ_para']['rho']
+ miu = config['Equ_para']['miu']
+ weight_b = config['weight_b']
+ weight_f = config['weight_f']
+
+ # define model
+ inputs = config['model']['inputs']
+ layers = config['model']['layers']
+ neurons = config['model']['neurons']
+ outputs = config['model']['outputs']
+ nn_architecture = [inputs] + layers*[neurons] + [outputs]
+ net = Net(nn_architecture)
+
+ # define optimizer
+ net_lb = CalculateLoss(net, x_inlet, u_inlet, x_wall, u_wall, x_boundary, u_boundary, x_outlet, p_outlet, volume,
+ rho, miu, weight_b, weight_f)
+ max_iters = config['max_iters']
+ lbfgs_train(net_lb, (x,), max_iters)
+
+ # train
+ loss = net_lb(x)
+ print(loss)
+ Visualization(net, x_wall, config)
+
+if __name__ == '__main__':
+ print("pid:", os.getpid())
+ start_time = time.time()
+ args = parse_args()
+ ms.context.set_context(device_target=args.device_target, device_id=args.device_id, mode=ms.PYNATIVE_MODE)
+ train(args)
+ end_time = time.time()
+ print("Train time: {} s".format(end_time - start_time))
+
\ No newline at end of file