🐞 Bug Description
When using TorchSharp's Module system with inheritance, if a derived class overrides a submodule (e.g., replaces a Sequential defined in the base class), calling RegisterComponents() again in the derived class does not correctly reflect the new structure in the state_dict(). The state dictionary still returns the parameters of the original module registered in the base class.
✅ Expected Behavior
When creating two C2 objects with different configurations (one using a regular Conv, the other using a custom DSConv), the state_dict() should reflect the actual registered submodules.
📎 Reproduction Code (Minimal Example)
/// <summary>
/// Custom convolution block used by the reproduction: a grouped ("depthwise")
/// convolution, a 1x1 ("pointwise") convolution and a batch-norm layer.
/// </summary>
/// <remarks>
/// The layers are held in private fields so that <c>RegisterComponents()</c> can
/// pick them up; the field names (<c>dw</c>, <c>pw</c>, <c>bn</c>) are the key
/// segments that show up in <c>state_dict()</c>.
/// </remarks>
public class DSConv : Module<Tensor, Tensor>
{
    private readonly Conv2d dw;
    private readonly Conv2d pw;
    private readonly BatchNorm2d bn;

    /// <summary>Builds the three sub-layers and registers them.</summary>
    /// <param name="inChannels">Number of input channels.</param>
    /// <param name="outChannels">Number of output channels.</param>
    /// <param name="c">
    /// Used as the kernel size of the depthwise convolution.
    /// NOTE(review): the original snippet never used this parameter; kernel size
    /// is the presumed meaning (callers pass 3) - confirm.
    /// </param>
    public DSConv(int inChannels, int outChannels, int c) :
        base(nameof(DSConv))
    {
        // The fields must be assigned before RegisterComponents() runs; a field
        // that is still null has nothing to contribute to state_dict().
        // bias: false keeps the convolutions weight-only, matching the
        // "dw.weight" / "pw.weight" keys listed in the report.
        dw = Conv2d(inChannels, inChannels, c, groups: inChannels, bias: false);
        pw = Conv2d(inChannels, outChannels, 1, bias: false);
        bn = BatchNorm2d(outChannels);

        RegisterComponents();
    }

    /// <summary>
    /// Identity pass-through. The reproduction only inspects <c>state_dict()</c>,
    /// so the layers are intentionally not applied here.
    /// </summary>
    /// <param name="input">Tensor to return unchanged.</param>
    /// <returns>The same <paramref name="input"/> instance.</returns>
    public override Tensor forward(Tensor input)
    {
        return input;
    }
}
/// <summary>
/// Regular convolution block used by the reproduction: one convolution followed
/// by a batch-norm layer.
/// </summary>
/// <remarks>
/// The layers are held in private fields so that <c>RegisterComponents()</c> can
/// pick them up; the field names (<c>conv</c>, <c>bn</c>) are the key segments
/// that show up in <c>state_dict()</c>.
/// </remarks>
public class Conv : Module<Tensor, Tensor>
{
    private readonly Conv2d conv;
    private readonly BatchNorm2d bn;

    /// <summary>Builds the two sub-layers and registers them.</summary>
    /// <param name="inChannels">Number of input channels.</param>
    /// <param name="outChannels">Number of output channels.</param>
    /// <param name="c">
    /// Used as the convolution kernel size.
    /// NOTE(review): the original snippet never used this parameter; kernel size
    /// is the presumed meaning (callers pass 3) - confirm.
    /// </param>
    public Conv(int inChannels, int outChannels, int c) :
        base(nameof(Conv))
    {
        // The fields must be assigned before RegisterComponents() runs; a field
        // that is still null has nothing to contribute to state_dict().
        // bias: false keeps the convolution weight-only, matching the
        // "conv.weight" key listed in the report.
        conv = Conv2d(inChannels, outChannels, c, bias: false);
        bn = BatchNorm2d(outChannels);

        RegisterComponents();
    }

    /// <summary>
    /// Identity pass-through. The reproduction only inspects <c>state_dict()</c>,
    /// so the layers are intentionally not applied here.
    /// </summary>
    /// <param name="input">Tensor to return unchanged.</param>
    /// <returns>The same <paramref name="input"/> instance.</returns>
    public override Tensor forward(Tensor input)
    {
        return input;
    }
}
/// <summary>
/// Base module of the reproduction. Owns a single <c>Sequential</c> named
/// <c>m</c> that contains one <see cref="Conv"/> block.
/// </summary>
public class C1 : Module<Tensor, Tensor>
{
    // Public and non-readonly on purpose: the derived class C2 reassigns it.
    public Sequential m;

    /// <summary>
    /// Creates <c>m</c> with one 3-sized <see cref="Conv"/> and registers the
    /// module's fields.
    /// </summary>
    /// <param name="inChannels">Forwarded to the <see cref="Conv"/> constructor.</param>
    /// <param name="outChannels">Forwarded to the <see cref="Conv"/> constructor.</param>
    public C1(int inChannels, int outChannels) : base(nameof(C1))
    {
        m = Sequential().append(new Conv(inChannels, outChannels, 3));

        // First registration of "m"; per the report this is the one that sticks.
        RegisterComponents();
    }

    /// <summary>Identity pass-through; the container is not applied.</summary>
    /// <param name="input">Tensor to return unchanged.</param>
    /// <returns>The same <paramref name="input"/> instance.</returns>
    public override Tensor forward(Tensor input) => input;
}
/// <summary>
/// Derived module of the reproduction. After the base constructor has built and
/// registered <c>m</c>, this constructor swaps in a freshly built
/// <c>Sequential</c> and calls <c>RegisterComponents()</c> a second time.
/// </summary>
/// <remarks>
/// This is the scenario the report is about: according to the report, the
/// second <c>RegisterComponents()</c> call does not replace the registration
/// made by <see cref="C1"/>. The sequence is kept as-is so the snippet keeps
/// demonstrating that behavior.
/// </remarks>
public class C2 : C1
{
    /// <summary>Builds the replacement container and re-runs registration.</summary>
    /// <param name="inChannels">Forwarded to the base class and to the inner block.</param>
    /// <param name="outChannels">Forwarded to the base class and to the inner block.</param>
    /// <param name="dsc3k">
    /// When <c>true</c> the container holds a <see cref="Conv"/>; when
    /// <c>false</c> it holds a <see cref="DSConv"/>.
    /// </param>
    public C2(int inChannels, int outChannels, bool dsc3k)
        : base(inChannels, outChannels)
    {
        // Build a new Sequential to take the place of the one created by C1.
        var replacement = Sequential();
        replacement.append(
            dsc3k
                ? (Module<Tensor, Tensor>)new Conv(inChannels, outChannels, 3)
                : new DSConv(inChannels, outChannels, 3));
        m = replacement;

        // Second registration attempt (see remarks on the class).
        RegisterComponents();
    }
}
/// <summary>
/// Entry point of the reproduction: builds one <see cref="C2"/> per
/// configuration and prints the keys of each <c>state_dict()</c>.
/// </summary>
class Program
{
    /// <summary>Writes every <c>state_dict()</c> key of <paramref name="model"/> on its own line.</summary>
    /// <param name="model">Module whose state dictionary keys are printed.</param>
    private static void PrintStateDictKeys(C2 model)
    {
        foreach (var key in model.state_dict().Keys)
        {
            Console.WriteLine(key);
        }
    }

    /// <summary>Runs the reproduction.</summary>
    /// <param name="args">Unused command-line arguments.</param>
    static void Main(string[] args)
    {
        // Both modules are constructed before anything is printed.
        var convVariant = new C2(3, 3, true);
        var dsConvVariant = new C2(3, 3, false);

        PrintStateDictKeys(convVariant);
        Console.WriteLine("======");
        PrintStateDictKeys(dsConvVariant);
    }
}
📌 Actual Output
m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
======
m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
✅ Expected Output
m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
======
m.0.dw.weight
m.0.pw.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
🔍 Root Cause
In the base class C1, RegisterComponents() registers the initial version of m. In the derived class C2, the m field is reassigned to a new Sequential (the field is replaced, not overridden), but even after calling RegisterComponents() again, the module originally registered by C1 remains in the module registry under the name m.
This leads to an incorrect state_dict() that still reflects the structure of the base class, not the new submodules created in the derived class.
🐞 Bug Description
When using TorchSharp's Module system with inheritance, if a derived class overrides a submodule (e.g., replaces a Sequential defined in the base class), calling RegisterComponents() again in the derived class does not correctly reflect the new structure in the state_dict(). The state dictionary still returns the parameters of the original module registered in the base class.
✅ Expected Behavior
When creating two C2 objects with different configurations (one using a regular Conv, the other using a custom DSConv), the state_dict() should reflect the actual registered submodules.
📎 Reproduction Code (Minimal Example)
public class DSConv : Module<Tensor, Tensor>
{
    private readonly Conv2d dw;
    private readonly Conv2d pw;
    private readonly BatchNorm2d bn;
    public DSConv(int inChannels, int outChannels, int c):
        base(nameof(DSConv))
    {
        RegisterComponents();
    }
    public override Tensor forward(Tensor input)
    {
        return input;
    }
}
public class Conv : Module<Tensor, Tensor>
{
    private readonly Conv2d conv;
    private readonly BatchNorm2d bn;
    public Conv(int inChannels, int outChannels, int c) :
        base(nameof(Conv))
    {
        RegisterComponents();
    }
    public override Tensor forward(Tensor input)
    {
        return input;
    }
}
public class C1 : Module<Tensor, Tensor>
{
    public Sequential m;
    public C1(int inChannels, int outChannels) : base(nameof(C1))
    {
        m = Sequential();
        m = m.append(new Conv(inChannels, outChannels, 3));
        RegisterComponents();
    }
    public override Tensor forward(Tensor input)
    {
        return input;
    }
}
public class C2 : C1
{
    public C2(int inChannels, int outChannels, bool dsc3k)
        : base(inChannels, outChannels)
    {
        // new a Sequential to replace the existing m
        m = Sequential();
        if (dsc3k)
        {
            m.append(new Conv(inChannels, outChannels, 3));
        }
        else
        {
            m.append(new DSConv(inChannels, outChannels, 3));
        }
        RegisterComponents();
    }
}
class Program
{
    static void Main(string[] args)
    {
        var c2_1 = new C2(3, 3, true);
        var c2_2 = new C2(3, 3, false);
        c2_1.state_dict().Keys.ToList().ForEach(k => Console.WriteLine(k));
        Console.WriteLine("======");
        c2_2.state_dict().Keys.ToList().ForEach(k => Console.WriteLine(k));
    }
}
📌 Actual Output
m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
======
m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
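✅ Expected Output
m.0.conv.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked
======
m.0.dw.weight
m.0.pw.weight
m.0.bn.weight
m.0.bn.bias
m.0.bn.running_mean
m.0.bn.running_var
m.0.bn.num_batches_tracked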
🔍 Root Cause
In the base class C1, RegisterComponents() registers the initial version of m. In the derived class C2, we override the m field with a new Sequential, but even after calling RegisterComponents() again, the old components from C1 remain in the module registry.
This leads to an incorrect state_dict() that still reflects the structure of the base class, not the new submodules created in the derived class.