!> MPIの初期化・集約を抽象化し、非MPIビルドでは単一ランク動作へフォールバックする。 module bem_mpi use bem_kinds, only: dp, i32 implicit none private type :: mpi_context integer(i32) :: rank = 0_i32 integer(i32) :: size = 1_i32 logical :: enabled = .false. #ifdef USE_MPI logical :: initialized_here = .false. #endif end type mpi_context public :: mpi_context public :: mpi_initialize public :: mpi_shutdown public :: mpi_is_root public :: mpi_world_size public :: mpi_get_rank_size public :: mpi_split_count public :: mpi_allreduce_sum_real_dp_array public :: mpi_allreduce_sum_real_dp_scalar public :: mpi_allreduce_min_real_dp_array public :: mpi_allreduce_max_real_dp_array public :: mpi_allreduce_sum_i32_array public :: mpi_allreduce_sum_i32_scalar public :: mpi_world_barrier contains !> MPIを初期化して rank / size を取得する。非MPIビルドでは単一ランクを返す。 subroutine mpi_initialize(ctx) type(mpi_context), intent(out) :: ctx #ifdef USE_MPI include 'mpif.h' logical :: is_initialized integer :: ierr integer :: rank_int, size_int #endif ctx = mpi_context() #ifdef USE_MPI call MPI_Initialized(is_initialized, ierr) if (.not. is_initialized) then call MPI_Init(ierr) ctx%initialized_here = .true. end if call MPI_Comm_rank(MPI_COMM_WORLD, rank_int, ierr) call MPI_Comm_size(MPI_COMM_WORLD, size_int, ierr) ctx%rank = int(rank_int, i32) ctx%size = int(size_int, i32) ctx%enabled = (ctx%size > 1_i32) #endif ! 非MPIビルド時、または MPI 実行系が 1 rank としか見えていない場合でも、 ! launcher 環境変数から rank/size を補完して root 専用ログの重複を避ける。 if (ctx%size <= 1_i32) call infer_launcher_rank_size(ctx) end subroutine mpi_initialize !> `mpi_initialize` が実際に初期化した場合のみ `MPI_Finalize` を呼ぶ。 subroutine mpi_shutdown(ctx) type(mpi_context), intent(inout) :: ctx #ifdef USE_MPI include 'mpif.h' logical :: is_finalized integer :: ierr if (ctx%initialized_here) then call MPI_Finalized(is_finalized, ierr) if (.not. is_finalized) call MPI_Finalize(ierr) ctx%initialized_here = .false. end if #endif ctx%rank = 0_i32 ctx%size = 1_i32 ctx%enabled = .false. end subroutine mpi_shutdown !> root rank (rank=0) かどうかを返す。 logical function mpi_is_root(ctx) type(mpi_context), intent(in) :: ctx mpi_is_root = (ctx%rank == 0_i32) end function mpi_is_root !> MPI world size を返す(size<=0 は 1 へ補正)。 integer(i32) function mpi_world_size(ctx) type(mpi_context), intent(in), optional :: ctx mpi_world_size = 1_i32 if (present(ctx)) mpi_world_size = max(1_i32, ctx%size) end function mpi_world_size !> `mpi_context` から rank/size を取得する。未指定時は単一rank(0/1)。 subroutine mpi_get_rank_size(rank, size, ctx) integer(i32), intent(out) :: rank, size type(mpi_context), intent(in), optional :: ctx rank = 0_i32 size = 1_i32 if (present(ctx)) then rank = ctx%rank size = max(1_i32, ctx%size) end if if (rank < 0_i32 .or. rank >= size) then error stop 'mpi_get_rank_size detected an invalid rank/size pair.' end if end subroutine mpi_get_rank_size !> 非MPIビルド時に launcher 環境変数から rank / size を補完する。 subroutine infer_launcher_rank_size(ctx) type(mpi_context), intent(inout) :: ctx logical :: found call try_launcher_env_pair(ctx, 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_SIZE', found) if (found) return call try_launcher_env_pair(ctx, 'PMI_RANK', 'PMI_SIZE', found) if (found) return call try_launcher_env_pair(ctx, 'MV2_COMM_WORLD_RANK', 'MV2_COMM_WORLD_SIZE', found) end subroutine infer_launcher_rank_size !> rank/size の環境変数ペアを解釈できたときだけ `ctx` を更新する。 subroutine try_launcher_env_pair(ctx, rank_name, size_name, found) type(mpi_context), intent(inout) :: ctx character(len=*), intent(in) :: rank_name, size_name logical, intent(out) :: found integer(i32) :: rank_value, size_value logical :: has_rank, has_size found = .false. call read_env_i32(rank_name, rank_value, has_rank) call read_env_i32(size_name, size_value, has_size) if (.not. has_rank .or. .not. has_size) return if (size_value <= 0_i32) return if (rank_value < 0_i32 .or. rank_value >= size_value) return ctx%rank = rank_value ctx%size = size_value ctx%enabled = (ctx%size > 1_i32) found = .true. end subroutine try_launcher_env_pair !> 整数環境変数を読み取る。未設定や parse 失敗時は `found=.false.`。 subroutine read_env_i32(name, value, found) character(len=*), intent(in) :: name integer(i32), intent(out) :: value logical, intent(out) :: found integer :: status, length, ios character(len=64) :: raw value = 0_i32 found = .false. raw = '' call get_environment_variable(name, raw, length=length, status=status) if (status /= 0 .or. length <= 0 .or. length > len(raw)) return read (raw(:length), *, iostat=ios) value if (ios /= 0) return found = .true. end subroutine read_env_i32 !> 総数 `total_count` をrankへ均等分割したときの局所個数を返す。 integer(i32) function mpi_split_count(total_count, rank, size) result(local_count) integer(i32), intent(in) :: total_count, rank, size integer(i32) :: base_count, n_remainder if (total_count < 0_i32) error stop 'mpi_split_count requires total_count >= 0.' if (size <= 0_i32) error stop 'mpi_split_count requires size > 0.' if (rank < 0_i32 .or. rank >= size) error stop 'mpi_split_count rank out of range.' base_count = total_count/size n_remainder = modulo(total_count, size) local_count = base_count if (rank < n_remainder) local_count = local_count + 1_i32 end function mpi_split_count !> 倍精度配列の総和Allreduceをin-placeで実行する。 subroutine mpi_allreduce_sum_real_dp_array(ctx, values) type(mpi_context), intent(in) :: ctx real(dp), intent(inout) :: values(:) #ifdef USE_MPI include 'mpif.h' real(dp), allocatable :: recvbuf(:) integer :: ierr if (.not. ctx%enabled) return allocate (recvbuf(size(values))) call MPI_Allreduce(values, recvbuf, size(values), MPI_DOUBLE_PRECISION, MPI_SUM, MPI_COMM_WORLD, ierr) values = recvbuf #endif end subroutine mpi_allreduce_sum_real_dp_array !> 倍精度スカラの総和Allreduceをin-placeで実行する。 subroutine mpi_allreduce_sum_real_dp_scalar(ctx, value) type(mpi_context), intent(in) :: ctx real(dp), intent(inout) :: value #ifdef USE_MPI include 'mpif.h' real(dp) :: recvval integer :: ierr if (.not. ctx%enabled) return call MPI_Allreduce(value, recvval, 1, MPI_DOUBLE_PRECISION, MPI_SUM, MPI_COMM_WORLD, ierr) value = recvval #endif end subroutine mpi_allreduce_sum_real_dp_scalar !> 倍精度配列の最小値Allreduceをin-placeで実行する。 subroutine mpi_allreduce_min_real_dp_array(ctx, values) type(mpi_context), intent(in) :: ctx real(dp), intent(inout) :: values(:) #ifdef USE_MPI include 'mpif.h' real(dp), allocatable :: recvbuf(:) integer :: ierr if (.not. ctx%enabled) return allocate (recvbuf(size(values))) call MPI_Allreduce(values, recvbuf, size(values), MPI_DOUBLE_PRECISION, MPI_MIN, MPI_COMM_WORLD, ierr) values = recvbuf #endif end subroutine mpi_allreduce_min_real_dp_array !> 倍精度配列の最大値Allreduceをin-placeで実行する。 subroutine mpi_allreduce_max_real_dp_array(ctx, values) type(mpi_context), intent(in) :: ctx real(dp), intent(inout) :: values(:) #ifdef USE_MPI include 'mpif.h' real(dp), allocatable :: recvbuf(:) integer :: ierr if (.not. ctx%enabled) return allocate (recvbuf(size(values))) call MPI_Allreduce(values, recvbuf, size(values), MPI_DOUBLE_PRECISION, MPI_MAX, MPI_COMM_WORLD, ierr) values = recvbuf #endif end subroutine mpi_allreduce_max_real_dp_array !> 32bit整数配列の総和Allreduceをin-placeで実行する。 subroutine mpi_allreduce_sum_i32_array(ctx, values) type(mpi_context), intent(in) :: ctx integer(i32), intent(inout) :: values(:) #ifdef USE_MPI include 'mpif.h' integer, allocatable :: sendbuf(:), recvbuf(:) integer :: ierr if (.not. ctx%enabled) return allocate (sendbuf(size(values)), recvbuf(size(values))) sendbuf = int(values, kind=kind(0)) call MPI_Allreduce(sendbuf, recvbuf, size(values), MPI_INTEGER, MPI_SUM, MPI_COMM_WORLD, ierr) values = int(recvbuf, kind=i32) #endif end subroutine mpi_allreduce_sum_i32_array !> 32bit整数スカラの総和Allreduceをin-placeで実行する。 subroutine mpi_allreduce_sum_i32_scalar(ctx, value) type(mpi_context), intent(in) :: ctx integer(i32), intent(inout) :: value #ifdef USE_MPI include 'mpif.h' integer :: sendval, recvval, ierr if (.not. ctx%enabled) return sendval = int(value, kind=kind(0)) call MPI_Allreduce(sendval, recvval, 1, MPI_INTEGER, MPI_SUM, MPI_COMM_WORLD, ierr) value = int(recvval, kind=i32) #endif end subroutine mpi_allreduce_sum_i32_scalar !> 全rankの同期ポイント。 subroutine mpi_world_barrier(ctx) type(mpi_context), intent(in) :: ctx #ifdef USE_MPI include 'mpif.h' integer :: ierr if (.not. ctx%enabled) return call MPI_Barrier(MPI_COMM_WORLD, ierr) #endif end subroutine mpi_world_barrier end module bem_mpi