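.net - ML.NET: unable to retrain a neural network

I am trying to retrain a model on new data. According to the documentation, the neural network's training algorithm supports retraining.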

// Build the training data from a collection
var trainingDataView = mlContext.Data.LoadFromEnumerable(trainDataFromDb.TrainData);

// Build the pipeline that prepares the data and trains on it
var trainingPipeline = mlContext.Transforms.Conversion.MapValueToKey(inputColumnName: "IdCategory", outputColumnName: "Label")
    .Append(mlContext.Transforms.Text.FeaturizeText(inputColumnName: "DescriptionProduct", outputColumnName: "DescriptionProductFeaturized"))
    .Append(mlContext.Transforms.Concatenate("Features", "DescriptionProductFeaturized"))
    .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy())
    .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

// Train the neural network
ITransformer trainedModel = trainingPipeline.Fit(trainingDataView);

// Use a FileStream
using (var fileStream = new FileStream(FullPathModelNeural, FileMode.Create, FileAccess.Write, FileShare.Write))
{
    // Save the trained neural network model
    mlContext.Model.Save(trainedModel, trainingDataView.Schema, fileStream);
}

The neural network handles the task it was given.

But when I want to retrain it, I use the following code:

// Load the trained model
var trainedModelFromFile = mlContext.Model.Load(FullPathModelNeural, out var modelSchema);

// Extract the parameters of the trained model
var originalModelParameters = ((ISingleFeaturePredictionTransformer<object>)trainedModelFromFile).Model as MaximumEntropyModelParameters;

// Build the training data from a collection
var trainingDataView = mlContext.Data.LoadFromEnumerable(trainDataFromDb.TrainData);

// Retrain the model on the new data
var retrainedModel = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy().Fit(trainingDataView, originalModelParameters);

// Use a FileStream
using (var fileStream = new FileStream(FullPathModelNeural, FileMode.Create, FileAccess.Write, FileShare.Write))
{
    // Save the retrained neural network model
    mlContext.Model.Save(retrainedModel, modelSchema, fileStream);
}

I get an exception: unable to cast an object of type Microsoft.ML.Data.TransformerChain [Microsoft.ML.ITransformer] to type Microsoft.ML.ISingleFeaturePredictionTransformer [System.Object]. It is thrown on this line:

var originalModelParameters = ((ISingleFeaturePredictionTransformer<object>)trainedModelFromFile).Model as MaximumEntropyModelParameters;

Please help me understand how to retrain the neural network correctly.

Best answer

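The problem is probably in how you saved the model: you saved more than just the trained estimator. The file contains the whole fitted pipeline (a TransformerChain), which is why the loaded object cannot be cast to ISingleFeaturePredictionTransformer.

A similar example that may help: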

IDataView inputDataView = mlContext.Data.LoadFromTextFile<ModelInput>("initialdata.csv");

// Prepare input data to a form consumable by a machine learning model
var inputDataPreparer = mlContext.Transforms.Text.FeaturizeText("Features", nameof(ModelInput.Comment)).Fit(inputDataView); 

// Create a training algorithm 
var trainer = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression();

var transformedData = inputDataPreparer.Transform(inputDataView);
var trainedModel = trainer.Fit(transformedData);

//Save the model
mlContext.Model.Save(trainedModel, inputDataView.Schema, "modelfile.zip");

// Save the input data preparation pipeline
mlContext.Model.Save(inputDataPreparer, inputDataView.Schema, "preparePipeline.zip");


// -- Later retraining step --

// Load data preparation pipeline
ITransformer dataPrepPipeline = mlContext.Model.Load("preparePipeline.zip", out _);

// Load trained model
ITransformer originalModel = mlContext.Model.Load("modelfile.zip", out _);

// Extract trained model parameters
var originalModelParameters = ((ISingleFeaturePredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>)originalModel).Model.SubModel; 

IDataView retrainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>("MOREDATA.CSV");

// Apply the data preparation pipeline to the new data (once)
IDataView transformedNewData = dataPrepPipeline.Transform(retrainingDataView);

// Retrain model
var retrainedModel = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression().Fit(transformedNewData, originalModelParameters);

// Save the complete pipeline for later use; its input is the raw, untransformed data
var completeRetrainedPipeline = dataPrepPipeline.Append(retrainedModel);
mlContext.Model.Save(completeRetrainedPipeline, retrainingDataView.Schema, "retrainedModel.zip");
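
The same idea can be applied to the multiclass pipeline from the question. The following is only a sketch, not a tested drop-in fix: the ProductData class, the MulticlassRetraining class and both file path parameters are hypothetical names introduced for illustration, and IdCategory is assumed to be a string. It fits the data preparation steps separately from the trainer, so that the saved model file holds a single prediction transformer that can be cast to ISingleFeaturePredictionTransformer<object>.

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Trainers;

public static class MulticlassRetraining
{
    // Hypothetical input type; property names follow the question's columns.
    public class ProductData
    {
        public string IdCategory { get; set; }          // assumed to be a string
        public string DescriptionProduct { get; set; }
    }

    // Initial training: save data preparation and trained model as two files.
    public static void TrainAndSave(MLContext mlContext, IEnumerable<ProductData> rows,
                                    string prepPath, string modelPath)
    {
        IDataView data = mlContext.Data.LoadFromEnumerable(rows);

        // Data preparation only: there is no trainer in this chain.
        ITransformer dataPrep = mlContext.Transforms.Conversion.MapValueToKey("Label", "IdCategory")
            .Append(mlContext.Transforms.Text.FeaturizeText("Features", "DescriptionProduct"))
            .Fit(data);

        IDataView prepared = dataPrep.Transform(data);

        // Fitting the trainer on its own yields a single prediction transformer.
        var model = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy().Fit(prepared);

        mlContext.Model.Save(dataPrep, data.Schema, prepPath);
        mlContext.Model.Save(model, prepared.Schema, modelPath);
    }

    // Retraining: reuse the fitted data preparation and warm-start the trainer.
    public static ITransformer Retrain(MLContext mlContext, IEnumerable<ProductData> newRows,
                                       string prepPath, string modelPath)
    {
        ITransformer dataPrep = mlContext.Model.Load(prepPath, out _);
        ITransformer originalModel = mlContext.Model.Load(modelPath, out _);

        // This cast works here because the file holds one prediction transformer,
        // not a TransformerChain.
        var originalParameters =
            ((ISingleFeaturePredictionTransformer<object>)originalModel).Model as MaximumEntropyModelParameters;
        if (originalParameters == null)
            throw new InvalidOperationException("The loaded model is not a maximum entropy model.");

        IDataView newData = mlContext.Data.LoadFromEnumerable(newRows);
        IDataView prepared = dataPrep.Transform(newData);

        var retrained = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy()
            .Fit(prepared, originalParameters);

        // Data preparation followed by the retrained model, ready for scoring raw rows.
        return dataPrep.Append(retrained);
    }
}
```

Reusing the already fitted preparation transformer, rather than fitting a new one on the new data, should keep the label keys and the feature vector consistent with what the original parameters were trained on. To get the original category values back in PredictedLabel, a MapKeyToValue step can be appended to the returned pipeline, as in the question's code.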

https://stackoverflow.com/questions/58093098/
