我正在尝试根据新数据重新训练模型。 神经网络学习算法允许根据文档进行再训练。
//формируем данные для обучения из коллекции
var trainingDataView =
mlContext.Data.LoadFromEnumerable(trainDataFromDb.TrainData);
//формируем данные в формате для обучения
var trainingPipeline = mlContext.Transforms.Conversion.MapValueToKey(inputColumnName: "IdCategory", outputColumnName: "Label")
.Append(mlContext.Transforms.Text.FeaturizeText(inputColumnName: "DescriptionProduct", outputColumnName: "DescriptionProductFeaturized"))
.Append(mlContext.Transforms.Concatenate("Features", "DescriptionProductFeaturized"))
.Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy())
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
//обучаем нейросеть
ITransformer trainedModel = trainingPipeline.Fit(trainingDataView);
//используем FileStream
using (var fileStream = new FileStream(FullPathModelNeural, FileMode.Create, FileAccess.Write, FileShare.Write))
{
//сохраняем обученную модель нейронной сети
mlContext.Model.Save(trainedModel, trainingDataView.Schema, fileStream);
}
神经网络处理分配给它的任务。
但只要我想重新训练它,我就会使用以下代码:
//загружаем обученную модель
var trainedModelFromFile = mlContext.Model.Load(FullPathModelNeural, out var modelSchema);
//извлекаем параметры обученной модели
var originalModelParameters = ((ISingleFeaturePredictionTransformer<object>)trainedModelFromFile).Model as MaximumEntropyModelParameters;
//формируем данные для обучения из коллекции
var trainingDataView = mlContext.Data.LoadFromEnumerable(trainDataFromDb.TrainData);
//дообучаем модель с учетом новых данных
var retrainedModel = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy().Fit(trainingDataView, originalModelParameters);
//используем FileStream
using (var fileStream = new FileStream(FullPathModelNeural, FileMode.Create, FileAccess.Write, FileShare.Write))
{
//сохраняем обученную модель нейронной сети
mlContext.Model.Save(retrainedModel, modelSchema, fileStream);
}
我得到一个异常(exception)。无法将对象类型 Microsoft.ML.Data.TransformerChain [Microsoft.ML.ITransformer] 转换为类型 Microsoft.ML.ISingleFeaturePredictionTransformer [System.Object] 在这行代码中
var originalModelParameters = ((ISingleFeaturePredictionTransformer<object>)trainedModelFromFile).Model as MaximumEntropyModelParameters;
请帮助我了解如何正确地重新训练神经网络
最佳答案
问题可能在于您如何保存模型,即您保存的不仅仅是 Estimator。
可能有帮助的类似示例:
IDataView inputDataView = mlContext.Data.LoadFromTextFile<ModelInput>("initialdata.csv");
// Prepare input data to a form consumable by a machine learning model
var inputDataPreparer = mlContext.Transforms.Text.FeaturizeText("Features", nameof(ModelInput.Comment)).Fit(inputDataView);
// Create a training algorithm
var trainer = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression();
var transformedData = inputDataPreparer.Transform(inputDataSplit.TrainSet);
var trainedModel = trainer.Fit(transformedData);
//Save the model
mlContext.Model.Save(trainedModel, inputDataView.Schema, "modelfile.zip");
//Save the input data preparing pipeline");
mlContext.Model.Save(inputDataPreparer, inputDataView.Schema, "preparePipeline.zip");
// -- Later retraining step --
// Load data preparation pipeline
ITransformer dataPrepPipeline = mlContext.Model.Load("preparePipeline.zip", out _);
// Load trained model
ITransformer originalModel = mlContext.Model.Load("modelfile.zip", out _);
// Extract trained model parameters
var originalModelParameters = ((ISingleFeaturePredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>)originalModel).Model.SubModel;
IDataView retrainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>("MOREDATA.CSV");
var newData = dataPrepPipeline.Transform(retrainingDataView);
IDataView transformedNewData = dataPrepPipeline.Transform(newData);
// Retrain model
var retrainedModel = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression().Fit(transformedNewData, originalModelParameters);
//Save complete pipeline for usage
var completeRetrainedPipeline = dataPrepPipeline.Append(retrainedModel);
mlContext.Model.Save(completeRetrainedPipeline, transformedNewData.Schema, "retrainedModel.zip);
https://stackoverflow.com/questions/58093098/
相关文章:
c# - 配置生成器 - 无法加载程序集 - .Net Framework
visual-studio-code - 如何使用 Visual Studio Code + VSC
microsoft-graph-api - 更新和删除日历事件而不向与会者发送通知
php - 未定义的方法 Laravel\Lumen\Application::booted()
ruby-on-rails - 仅导入主文件时,SASS 变量不起作用
reactjs - 如何使用带有反应最终形式的自定义 radio 组件?
java - 尝试通过运行 `mvn test` 来运行测试时 JVM 崩溃