다음을 통해 공유


TensorFlowEstimator 클래스

정의

다음 TensorFlowTransformer 두 가지 시나리오에서 사용됩니다.

  1. 미리 학습된 TensorFlow 모델을 사용한 채점: 이 모드에서 변환은 미리 학습된 Tensorflow 모델에서 숨겨진 계층의 값을 추출하고 출력을 ML.Net 파이프라인의 기능으로 사용합니다.
  2. TensorFlow 모델 재학습: 이 모드에서 변환은 ML.Net 파이프라인을 통해 전달된 사용자 데이터를 사용하여 TensorFlow 모델을 재학습합니다. 모델이 학습되면 출력을 점수 매기기 기능으로 사용할 수 있습니다.
public sealed class TensorFlowEstimator : Microsoft.ML.IEstimator<Microsoft.ML.Transforms.TensorFlowTransformer>
type TensorFlowEstimator = class
    interface IEstimator<TensorFlowTransformer>
Public NotInheritable Class TensorFlowEstimator
Implements IEstimator(Of TensorFlowTransformer)
상속
TensorFlowEstimator
구현

설명

TensorFlowTransform은 미리 학습된 Tensorflow 모델을 사용하여 지정된 출력을 추출합니다. 필요에 따라 사용자 데이터에 대한 TensorFlow 모델을 추가로 재학습하여 사용자 데이터에 대한 모델 매개 변수를 조정할 수 있습니다("전송 학습"으로도 알려짐).

채점의 경우 변환은 미리 학습된 Tensorflow 모델, 입력 노드의 이름 및 추출할 값을 가진 출력 노드의 이름을 입력으로 사용합니다. 또한 재학습을 위해 변환에는 TensorFlow 그래프의 최적화 작업 이름, 그래프의 학습 속도 작업 이름 및 해당 값, 손실 및 성능 메트릭을 계산하기 위한 그래프의 작업 이름 등과 같은 학습 관련 매개 변수가 필요합니다.

이 변환을 수행하려면 Microsoft.ML.TensorFlow nuget을 설치해야 합니다. TensorFlowTransform에는 입력, 출력, 데이터 처리 및 재학습과 관련하여 다음과 같은 가정이 있습니다.

  1. 입력 모델의 경우 현재 TensorFlowTransform은 Frozen 모델 형식과 SavedModel 형식을 모두 지원합니다. 그러나 모델의 재학습은 SavedModel 형식에 대해서만 가능합니다. 검사점 형식은 현재 텐서플로 C-API의 로딩 지원이 부족하여 점수 매기기 또는 재학습에 지원되지 않습니다.
  2. 변환은 한 번에 하나의 예제만 채점할 수 있습니다. 그러나 재학습은 일괄 처리로 수행할 수 있습니다.
  3. 고급 전송 학습/미세 조정 시나리오(예: 네트워크에 더 많은 계층 추가, 입력 모양 변경, 재학습 프로세스 중에 업데이트할 필요가 없는 계층 고정 등)는 TensorFlow C-API를 사용하여 모델 내에서 네트워크/그래프 조작에 대한 지원이 부족하기 때문에 현재 불가능합니다.
  4. 입력 열의 이름은 TensorFlow 모델의 입력 이름과 일치해야 합니다.
  5. 각 출력 열의 이름은 TensorFlow 그래프의 작업 중 하나와 일치해야 합니다.
  6. 현재 double, float, long, int, short, sbyte, ulong, uint, ushort, byte 및 bool은 입력/출력에 허용되는 데이터 형식입니다.
  7. 성공하면 변환에 지정된 각 출력 열에 IDataView 해당하는 새 열이 도입됩니다.

TensorFlow 모델의 입력 및 출력은 summarize_graph 도구를 사용하여 GetModelSchema() 가져올 수 있습니다.

메서드

Fit(IDataView)

를 학습하고 반환합니다 TensorFlowTransformer.

GetOutputSchema(SchemaShape)

변환기에서 SchemaShape 생성할 스키마를 반환합니다. 파이프라인에서 스키마 전파 및 확인에 사용됩니다.

확장 메서드

AppendCacheCheckpoint<TTrans>(IEstimator<TTrans>, IHostEnvironment)

추정기 체인에 '캐싱 검사점'을 추가합니다. 이렇게 하면 다운스트림 추정기가 캐시된 데이터에 대해 학습됩니다. 여러 데이터 전달을 수행하는 트레이너 앞에 캐싱 검사점이 있는 것이 좋습니다.

WithOnFitDelegate<TTransformer>(IEstimator<TTransformer>, Action<TTransformer>)

추정기가 지정된 경우 호출된 대리 Fit(IDataView) 자를 호출할 래핑 개체를 반환합니다. 예측 도구가 적합한 항목에 대한 정보를 반환하는 것이 중요한 경우가 많습니다. 따라서 Fit(IDataView) 메서드는 일반 ITransformer개체가 아닌 구체적으로 형식화된 개체를 반환합니다. 그러나 동시에 IEstimator<TTransformer> 개체가 많은 파이프라인으로 형성되는 경우가 많으므로 변환기를 가져올 추정기가 이 체인의 어딘가에 묻혀 있는 위치를 통해 EstimatorChain<TLastTransformer> 추정기 체인을 빌드해야 할 수 있습니다. 이 시나리오에서는 이 메서드를 통해 fit이 호출되면 호출될 대리자를 연결할 수 있습니다.

적용 대상