function [testY_Pred, rmse, net, Xmu, Xsig, Ymu, Ysig]=learning_PMWG_function_LSTM1data(X_train,Y_train, X_test, Y_test)
% ネットワーク学習に関する設定
training = true;
params.plot = 'training-progress'; % 'training-progress' / 'none'
params.executionenv = 'cpu';

%% 過渡応答データの準備
% 平均・分散の計算
Xmu = mean(X_train,2);
Xsig = std(X_train,0,2);
Ymu = mean(Y_train,2);
Ysig = std(Y_train,0,2);
% 平均・分散を用いてデータを標準化
std_trainX = (X_train - Xmu) ./ Xsig;
std_trainY = (Y_train - Ymu) ./ Ysig;
std_testX  = (X_test - Xmu) ./ Xsig;
std_testY  = (Y_test - Ymu) ./ Ysig;

%% 過渡応答用ネットワークの作成
% 入出力の次元
numFeatures = size(X_train,1);
numResponses = size(Y_train,1);
% LSTMノードの数
numHiddenUnits = 50;

% ネットワークの定義
layers = [ ...
    sequenceInputLayer(numFeatures,"Name","Sequenceinput")
    lstmLayer(numHiddenUnits,"Name","lstm","OutputMode","sequence")
    fullyConnectedLayer(numResponses,"Name","fc2")
    regressionLayer("Name","regressionoutput")];
lgraph = layerGraph(layers);

% 学習パラメータの設定
maxEpochs = 50;% 損失関数がほぼ収束するために必要な値。今回は10と仮定
miniBatchSize = 64; % 学習データセットをサブセットへ分割する際のデータ数。2のn乗の値とすることが多い。今回は64を採用
options = trainingOptions('adam', ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'InitialLearnRate',0.01, ...
    'ValidationFrequency',50,... 
    'Shuffle','every-epoch', ...
    'Plots',params.plot,...
    'Verbose',0);
% ネットワークの学習
if training == true
    disp('train LSTM model...')
    % ネットワークの学習
    [net, info] = trainNetwork(std_trainX, std_trainY, lgraph, options);
else
    % trainingフラグがFalseであればネットワークの読み込み
    load('net.mat');
end

%% 推論(学習したモデルを実行し数値を予測)と評価
% trainは学習用。testは検証用。
[~, std_trainY_Pred] = predictAndUpdateState(net, std_trainX, 'MiniBatchSize',1);
[~, std_testY_Pred]  = predictAndUpdateState(net, std_testX,  'MiniBatchSize',1);
% 標準化して扱ってきたので、元のスケールに戻す。
trainY_Pred = (std_trainY_Pred .* Ysig) + Ymu;
testY_Pred  = (std_testY_Pred  .* Ysig) + Ymu;
% RMSEを計算
rmse = sqrt(mean((Y_test - testY_Pred)'.^2));

%% 結果をグラフで表示
modeltype = "LSTM";
% グラフ＆評価結果(左側：学習用データ、右側：評価用データ)
figure('Name', '推論'),
for n = 1 : size(Y_train, 1)
    subplot(size(Y_train, 1), 2, 2*n-1)
    plot(Y_train(n, :),'DisplayName','standardizedYtrain2');
    hold on;
    plot(trainY_Pred(n, :),'DisplayName','YPred');
    legend('真値', '推論値')
    grid on
    hold off;
    if n == 1
        title(["【学習：SOC】データ番号："+ num2str(n)]);
        xlabel('Index [-]')
        ylabel('SOC [%]')
    elseif n == 2
        title(["【学習：Voltage】データ番号："+ num2str(n)]);
        xlabel('Index [-]')
        ylabel('Voltage [V]')
    end

    subplot(size(Y_train, 1), 2, 2*n)
    plot(Y_test(n, :),'DisplayName','standardizedYtrain2');
    hold on;
    plot(testY_Pred(n, :),'DisplayName','YPred');
    legend('真値', '推論値')
    grid on
    hold off;
    if n == 1
        title(["【検証：SOC】データ番号："+ num2str(n)]);
        xlabel('Index [-]')
        ylabel('SOC [%]')
    elseif n == 2
        title(["【検証：Voltage】データ番号："+ num2str(n)]);
        xlabel('Index [-]')
        ylabel('Voltage [V]')
    end
end
end