@@ -22,6 +22,7 @@ limitations under the License.
22
22
using System . Data ;
23
23
using System . Diagnostics ;
24
24
using System . Linq ;
25
+ using System . Reflection ;
25
26
using Tensorflow . Keras . ArgsDefinition ;
26
27
using Tensorflow . Keras . Engine ;
27
28
using Tensorflow . Keras . Layers ;
@@ -58,59 +59,32 @@ public static JObject serialize_keras_object(IKerasConfigable instance)
58
59
59
60
public static Layer deserialize_keras_object ( string class_name , JToken config )
60
61
{
61
- return class_name switch
62
- {
63
- "Sequential" => new Sequential ( config . ToObject < SequentialArgs > ( ) ) ,
64
- "InputLayer" => new InputLayer ( config . ToObject < InputLayerArgs > ( ) ) ,
65
- "Flatten" => new Flatten ( config . ToObject < FlattenArgs > ( ) ) ,
66
- "ELU" => new ELU ( config . ToObject < ELUArgs > ( ) ) ,
67
- "Dense" => new Dense ( config . ToObject < DenseArgs > ( ) ) ,
68
- "Softmax" => new Softmax ( config . ToObject < SoftmaxArgs > ( ) ) ,
69
- "Conv2D" => new Conv2D ( config . ToObject < Conv2DArgs > ( ) ) ,
70
- "BatchNormalization" => new BatchNormalization ( config . ToObject < BatchNormalizationArgs > ( ) ) ,
71
- "MaxPooling2D" => new MaxPooling2D ( config . ToObject < MaxPooling2DArgs > ( ) ) ,
72
- "Dropout" => new Dropout ( config . ToObject < DropoutArgs > ( ) ) ,
73
- _ => throw new NotImplementedException ( $ "The deserialization of <{ class_name } > has not been supported. Usually it's a miss during the development. " +
74
- $ "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
75
- } ;
62
+ var argType = Assembly . Load ( "Tensorflow.Binding" ) . GetType ( $ "Tensorflow.Keras.ArgsDefinition.{ class_name } Args") ;
63
+ var deserializationMethod = typeof ( JToken ) . GetMethods ( BindingFlags . Instance | BindingFlags . Public )
64
+ . Single ( x => x . Name == "ToObject" && x . IsGenericMethodDefinition && x . GetParameters ( ) . Count ( ) == 0 ) ;
65
+ var deserializationGenericMethod = deserializationMethod . MakeGenericMethod ( argType ) ;
66
+ var args = deserializationGenericMethod . Invoke ( config , null ) ;
67
+ var layer = Assembly . Load ( "Tensorflow.Keras" ) . CreateInstance ( $ "Tensorflow.Keras.Layers.{ class_name } ", true , BindingFlags . Default , null , new object [ ] { args } , null , null ) ;
68
+ Debug . Assert ( layer is Layer ) ;
69
+ return layer as Layer ;
76
70
}
77
71
78
72
public static Layer deserialize_keras_object ( string class_name , LayerArgs args )
79
73
{
80
- return class_name switch
81
- {
82
- "Sequential" => new Sequential ( args as SequentialArgs ) ,
83
- "InputLayer" => new InputLayer ( args as InputLayerArgs ) ,
84
- "Flatten" => new Flatten ( args as FlattenArgs ) ,
85
- "ELU" => new ELU ( args as ELUArgs ) ,
86
- "Dense" => new Dense ( args as DenseArgs ) ,
87
- "Softmax" => new Softmax ( args as SoftmaxArgs ) ,
88
- "Conv2D" => new Conv2D ( args as Conv2DArgs ) ,
89
- "BatchNormalization" => new BatchNormalization ( args as BatchNormalizationArgs ) ,
90
- "MaxPooling2D" => new MaxPooling2D ( args as MaxPooling2DArgs ) ,
91
- "Dropout" => new Dropout ( args as DropoutArgs ) ,
92
- _ => throw new NotImplementedException ( $ "The deserialization of <{ class_name } > has not been supported. Usually it's a miss during the development. " +
93
- $ "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
94
- } ;
74
+ var layer = Assembly . Load ( "Tensorflow.Keras" ) . CreateInstance ( $ "Tensorflow.Keras.Layers.{ class_name } ", true , BindingFlags . Default , null , new object [ ] { args } , null , null ) ;
75
+ Debug . Assert ( layer is Layer ) ;
76
+ return layer as Layer ;
95
77
}
96
78
97
- public static LayerArgs ? deserialize_layer_args ( string class_name , JToken config )
79
+ public static LayerArgs deserialize_layer_args ( string class_name , JToken config )
98
80
{
99
- return class_name switch
100
- {
101
- "Sequential" => config . ToObject < SequentialArgs > ( ) ,
102
- "InputLayer" => config . ToObject < InputLayerArgs > ( ) ,
103
- "Flatten" => config . ToObject < FlattenArgs > ( ) ,
104
- "ELU" => config . ToObject < ELUArgs > ( ) ,
105
- "Dense" => config . ToObject < DenseArgs > ( ) ,
106
- "Softmax" => config . ToObject < SoftmaxArgs > ( ) ,
107
- "Conv2D" => config . ToObject < Conv2DArgs > ( ) ,
108
- "BatchNormalization" => config . ToObject < BatchNormalizationArgs > ( ) ,
109
- "MaxPooling2D" => config . ToObject < MaxPooling2DArgs > ( ) ,
110
- "Dropout" => config . ToObject < DropoutArgs > ( ) ,
111
- _ => throw new NotImplementedException ( $ "The deserialization of <{ class_name } > has not been supported. Usually it's a miss during the development. " +
112
- $ "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
113
- } ;
81
+ var argType = Assembly . Load ( "Tensorflow.Binding" ) . GetType ( $ "Tensorflow.Keras.ArgsDefinition.{ class_name } Args") ;
82
+ var deserializationMethod = typeof ( JToken ) . GetMethods ( BindingFlags . Instance | BindingFlags . Public )
83
+ . Single ( x => x . Name == "ToObject" && x . IsGenericMethodDefinition && x . GetParameters ( ) . Count ( ) == 0 ) ;
84
+ var deserializationGenericMethod = deserializationMethod . MakeGenericMethod ( argType ) ;
85
+ var args = deserializationGenericMethod . Invoke ( config , null ) ;
86
+ Debug . Assert ( args is LayerArgs ) ;
87
+ return args as LayerArgs ;
114
88
}
115
89
116
90
public static ModelConfig deserialize_model_config ( JToken json )
0 commit comments